[Carver][Bugfix] Correct score function for warp tile selection in tensorcore policy #724
```diff
@@ -187,7 +187,7 @@ def get_configs(args, kwargs):
     from tilelang.carver.roller.rasterization import NoRasterization
     import torch

-    arch = CDNA("cuda") if torch.version.hip is None else CUDA("hip")
+    arch = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
     topk = 10

     carve_template = MatmulTemplate(
```

**Review comment** (applies to each of the four identical fixes in this PR): This change corrects a critical bug in the architecture detection logic. The previous implementation had the `CUDA` and `CDNA` architectures swapped, which would cause incorrect architecture-specific code paths to be taken. This fix ensures the correct platform is identified based on the presence of `torch.version.hip`.
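For context, `torch.version.hip` is `None` on CUDA builds of PyTorch and holds a version string on ROCm builds, which is what makes this one-liner a reliable backend check. A minimal sketch of the corrected pattern, assuming the `CUDA`/`CDNA` classes come from `tilelang.carver.arch` (import path assumed from the surrounding examples):

```python
import torch

from tilelang.carver.arch import CDNA, CUDA  # import path assumed

def detect_arch():
    """Return the arch object matching the installed PyTorch build."""
    # torch.version.hip is None on CUDA builds of PyTorch and a version
    # string on ROCm builds, so its presence signals an AMD platform.
    if torch.version.hip is None:
        return CUDA("cuda")  # NVIDIA GPU: take CUDA-specific code paths
    return CDNA("hip")       # AMD GPU: take CDNA (ROCm/HIP) code paths
```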
```diff
@@ -96,7 +96,7 @@ def conv(

 def main():
     my_func = kernel(N, C, H, W, F, K, S, D, P, 64, 128, 32, 3, 256)
-    cuda_device = CDNA("cuda") if torch.version.hip is None else CUDA("hip")
+    cuda_device = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
     result = Analyzer.analysis(my_func, cuda_device)
     print(result)
     print(f"Analyzed FLOPs: {result.total_flops}")
```
```diff
@@ -49,7 +49,7 @@ def matmul(
 def main():
     my_func = kernel(128, 128, 32, 3, 128, True)

-    cuda_device = CDNA("cuda") if torch.version.hip is None else CUDA("hip")
+    cuda_device = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
     result = Analyzer.analysis(my_func, cuda_device)

     print(f"Analyzed FLOPs: {result.total_flops}")
```
```diff
@@ -16,7 +16,7 @@ def ref_program(A, B):

 def get_configs(M, N, K, with_roller=False, topk=20):
     if with_roller:
-        arch = CDNA("cuda") if torch.version.hip is None else CUDA("hip")
+        arch = CUDA("cuda") if torch.version.hip is None else CDNA("hip")
         carve_template = MatmulTemplate(
             M=M,
             N=N,
```
```diff
@@ -281,10 +281,9 @@ def _assign_block_size(self, node: PrimFuncNode, td: TileDict, block_size: int):

         factors = factorize(np.prod(space) // warps)

-        def _score(node, thread):  # small is better
+        def _score(node, warp_tile):  # small is better
             score = 0
-            block_tile = [int(np.ceil(tile[i] / thread[i])) for i in range(ndim)]
-            shape = node.propagate_inputs_on_reduction(block_tile)
+            shape = node.propagate_inputs_on_reduction(warp_tile)
             input_buffers = node.block_analyzer.get_input_buffers(node.reduction_block)
             for i, _ in enumerate(input_buffers):
                 score += np.prod(shape[i]) / self.arch.bandwidth[1]
```

**Review comment on lines +284 to +286:** This is an excellent correction to the scoring function's logic. The previous implementation was flawed, as it used the number of warps per dimension to estimate data traffic instead of the actual warp tile shape. By passing the warp tile directly to `propagate_inputs_on_reduction`, the score now reflects the input data each candidate tile actually reads.
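To see why this matters, here is a self-contained sketch of the corrected scoring idea. `score_warp_tile`, `propagate_inputs`, and the toy GEMM propagation rule below are illustrative stand-ins, not tilelang's actual `PrimFuncNode` API:

```python
import numpy as np

def score_warp_tile(warp_tile, propagate_inputs, bandwidth):
    """Estimate per-warp memory traffic for a candidate tile (smaller is better)."""
    score = 0.0
    # The fix in this PR: propagate the warp tile itself, rather than a
    # tile derived from per-dimension warp counts, so each entry of
    # `shapes` is the input region one warp actually reads.
    shapes = propagate_inputs(warp_tile)
    for shape in shapes:
        score += np.prod(shape) / bandwidth
    return score

# Toy GEMM propagation rule: an [m, n] output tile with reduction length K
# reads an [m, K] slice of A and a [K, n] slice of B.
K = 1024

def gemm_inputs(tile):
    return [(tile[0], K), (K, tile[1])]

# Tiles with the same area can differ in traffic; the score ranks them.
print(score_warp_tile([64, 64], gemm_inputs, bandwidth=1e4))   # (64 + 64) * 1024 / 1e4
print(score_warp_tile([32, 128], gemm_inputs, bandwidth=1e4))  # (32 + 128) * 1024 / 1e4
```

Under this toy rule the squarer tile scores lower (less traffic for the same output area), which is the behavior the corrected `_score` is meant to reward.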