fix alignmentC for h16816_s8xf16 (NVIDIA#1146)
* fix alignmentC for h16816_s8xf16

* manish's change
hwu36 authored Oct 17, 2023
1 parent 757275f commit 5e1a0a5
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions python/cutlass_library/generator.py
@@ -2225,7 +2225,7 @@ def GenerateSM80_TensorOp_16816_mixed_input_upcast_a(manifest, cuda_version):
     math_inst.element_accumulator,
   ]

-  CreateGemmOperator(manifest, layouts, tile_descriptions, \
+  operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \
     data_type, alignment_constraints)

   # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation)
@@ -2238,11 +2238,12 @@ def GenerateSM80_TensorOp_16816_mixed_input_upcast_a(manifest, cuda_version):
     math_inst.element_accumulator,
   ]

-  operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \
+  operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \
     data_type_mixed, alignment_constraints)

   for op in operations:
-    if op.tile_description.threadblock_shape[1] <= 32:
+    if (DataTypeSize[op.C.element] == 16) and \
+       (op.tile_description.threadblock_shape[1] <= 32):
       op.C.alignment = 4
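The effect of the changed loop can be illustrated with a minimal, self-contained sketch. The `Op`, `TensorDesc`, and `TileDescription` classes and the `DataTypeSize` table below are simplified stand-ins for the real `cutlass_library` structures (assumed attribute shapes, not the actual classes): the fix lowers the C-operand alignment to 4 only when the C element is 16 bits wide and the threadblock N dimension is at most 32, instead of for every narrow tile.

```python
from dataclasses import dataclass

# Simplified stand-in for cutlass_library's DataTypeSize table (bits per element).
DataTypeSize = {"f16": 16, "s8": 8}

@dataclass
class TensorDesc:
    element: str
    alignment: int = 8

@dataclass
class TileDescription:
    threadblock_shape: list  # [M, N, K]

@dataclass
class Op:
    C: TensorDesc
    tile_description: TileDescription

ops = [
    Op(TensorDesc("f16"), TileDescription([128, 32, 64])),   # narrow tile, 16-bit C
    Op(TensorDesc("s8"),  TileDescription([128, 32, 64])),   # narrow tile, 8-bit C
    Op(TensorDesc("f16"), TileDescription([128, 128, 64])),  # wide tile, 16-bit C
]

# The committed fix: reduce C alignment to 4 only for 16-bit C elements
# on threadblocks whose N dimension is at most 32.
for op in ops:
    if (DataTypeSize[op.C.element] == 16) and \
       (op.tile_description.threadblock_shape[1] <= 32):
        op.C.alignment = 4

print([op.C.alignment for op in ops])  # -> [4, 8, 8]
```

Only the first operation is touched; the 8-bit-C operation keeps its original alignment even though its tile is narrow, which is the behavior the `DataTypeSize` guard adds.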
