Skip to content

Commit

Permalink
fix
Browse files: browse the repository at this point in the history
  • Loading branch information
masahi committed Feb 4, 2022
1 parent 4a383e2 commit 7a51995
Showing 1 changed file with 17 additions and 24 deletions.
41 changes: 17 additions & 24 deletions python/tvm/contrib/cutlass/conv2d_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,22 +174,14 @@ def __init__(self):
${unary_op}
>"""

self.epilogue_wgrad_split_k = """
self.epilogue_wgrad = """
${epilogue_functor}<
${element_c},
4,
float,
float,
>"""

self.epilogue_wgrad_split_k_tmp = """
${epilogue_functor}<
float,
4,
float,
float,
>"""

self.template = """
// Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
using ${operation_name} =
Expand Down Expand Up @@ -252,26 +244,19 @@ def emit(self, operation, no_beta_scaling=False, residual_block_info=False):
)

element_c = operation.C.element
use_split_k = (
operation.split_k_slices > 1
and operation.conv_kind == ConvKind.Wgrad
and operation.C.element == DataType.f16
)
element_c_gemm = element_c
use_split_k_wgrad = operation.conv_kind == ConvKind.Wgrad and operation.split_k_slices > 1

if use_split_k:
# split k
element_c = DataType.f32
if use_split_k_wgrad:
# split k, assumes fp32 accum. gemm output always fp32
element_c_gemm = DataType.f32
epilogue_reduction = substitute_template(
self.epilogue_wgrad_split_k,
self.epilogue_wgrad,
{
"epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor],
"element_c": DataTypeTag[element_c],
},
)
epilogue_gemm = substitute_template(
self.epilogue_wgrad_split_k_tmp,
{"epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor]},
)
reduction = substitute_template(
self.reduction_template,
{
Expand All @@ -292,7 +277,7 @@ def emit(self, operation, no_beta_scaling=False, residual_block_info=False):
"layout_a": LayoutTag[operation.A.layout],
"element_b": DataTypeTag[operation.B.element],
"layout_b": LayoutTag[operation.B.layout],
"element_c": DataTypeTag[element_c],
"element_c": DataTypeTag[element_c_gemm],
"layout_c": LayoutTag[operation.C.layout],
"element_accumulator": DataTypeTag[operation.accumulator_type()],
"opcode_class": OpcodeClassTag[
Expand Down Expand Up @@ -332,7 +317,15 @@ def emit(self, operation, no_beta_scaling=False, residual_block_info=False):
"conv_kernel_postfix": "",
}

if use_split_k:
if use_split_k_wgrad:
# Even if the output is fp16, gemm output is always fp32 for split k wgrad.
epilogue_gemm = substitute_template(
self.epilogue_wgrad,
{
"epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor],
"element_c": "float",
},
)
template = substitute_template(self.template, {"epilogue": epilogue_gemm})
elif residual_block_info:
template = substitute_template(
Expand Down

0 comments on commit 7a51995

Please sign in to comment.