Add cutlass conv2d profiler
commit 1c0bbb2
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sun Dec 12 18:29:03 2021 +0900

    fix lint

commit 463574c
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sun Dec 12 17:28:38 2021 +0900

    fixed conv2d check

commit 588c5ab
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sun Dec 12 15:05:27 2021 +0900

    update test

commit a447b57
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sun Dec 12 14:54:52 2021 +0900

    speed up profiling by removing initialization

commit 93cd039
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sun Dec 12 08:26:29 2021 +0900

    fixed nhwc cudnn depthwise conv

commit 6db7172
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sat Dec 11 15:39:05 2021 +0900

    add cache

commit f7d17a1
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sat Dec 11 15:05:38 2021 +0900

    removed im2col profiling for conv2d

commit b724f44
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Fri Dec 10 22:57:54 2021 +0900

    black

commit fe4687b
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Fri Dec 10 22:49:13 2021 +0900

    fixed cmd argument

commit ab114f5
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Fri Dec 10 22:22:19 2021 +0900

    conv2d profiler working

commit 49ee61f
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Fri Dec 10 20:26:15 2021 +0900

    add conv2d profiler

commit 49e2c89
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Sun Dec 12 08:03:36 2021 +0900

    do not offload depthwise conv2d

commit cd83677
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Fri Dec 10 13:20:01 2021 +0900

    lint fix

commit 870823c
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Fri Dec 10 12:54:38 2021 +0900

    add comment on IC == 3 case

commit 6b780db
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Fri Dec 10 12:48:33 2021 +0900

    check align on N dim

commit 308c4da
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Fri Dec 10 12:34:42 2021 +0900

    fixed check functions for fused cases, run infer type before MergeComposite

commit 8d6a1bf
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Fri Dec 10 12:10:59 2021 +0900

    test IC=3 convolution

commit ffce47d
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Fri Dec 10 12:10:16 2021 +0900

    use align1 kernel for unusual channel cases (IC = 3 etc)

commit 6cdf205
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Fri Dec 10 12:06:56 2021 +0900

    add dtype and layout check in pattern match

commit 7743cc6
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Fri Dec 10 10:40:53 2021 +0900

    add sm75 kernels to sm80 profiling

commit efceccb
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Fri Dec 10 10:40:42 2021 +0900

    skip legalize when batch size is dynamic

commit 65fbc0a
Author: Masahiro Masuda <masahi129@gmail.com>
Date:   Fri Dec 10 10:36:36 2021 +0900

    bug fix in im2col encoding
masahi committed Dec 14, 2021
1 parent 69cae0b commit 3d3d24f
Showing 7 changed files with 259 additions and 34 deletions.
9 changes: 9 additions & 0 deletions python/tvm/contrib/cutlass/build.py
@@ -185,6 +185,9 @@ def handle_conv2d(
d_shape,
w_shape,
out_shape,
padding,
strides,
dilation,
out_dtype,
profile_all,
use_multiprocessing,
@@ -198,6 +201,9 @@
d_shape,
w_shape,
out_shape,
padding,
strides,
dilation,
out_dtype,
profile_all=profile_all,
use_multiprocessing=use_multiprocessing,
@@ -279,6 +285,9 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, t
arg0_shape,
arg1_shape,
annotator.signature["ret_shape"],
annotator.op_attrs.padding,
annotator.op_attrs.strides,
annotator.op_attrs.dilation,
out_dtype,
profile_all,
use_multiprocessing,
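For context, a minimal end-to-end sketch of how this entry point is driven. The import paths, the prior partition_for_cutlass call, and the two-value return are recalled from the TVM tree of this period and are assumptions here, not part of this diff; the argument values are illustrative.

from tvm.relay.op.contrib.cutlass import partition_for_cutlass
from tvm.contrib.cutlass import tune_cutlass_kernels

mod = partition_for_cutlass(mod)  # offload supported conv2d/dense patterns to CUTLASS
# The padding/strides/dilation of each matched conv2d are now read from the
# annotated op attributes and forwarded, together with the shapes, through
# handle_conv2d into the conv2d profiler.
mod, num_cutlass_partition = tune_cutlass_kernels(
    mod, sm=80, profile_all=True, use_multiprocessing=True, tmp_dir="./tmp"
)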
163 changes: 163 additions & 0 deletions python/tvm/contrib/cutlass/conv2d_profiler.py
@@ -0,0 +1,163 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=import-outside-toplevel, invalid-name
"""Instantiate a C++ source for profiling CUTLASS kernels."""


class Conv2dProfilerEmitter(object):
"""Emit a C++ source for profiling CUTLASS kernels."""

def __init__(self):
from jinja2 import Template

self.template = Template(
"""
#include <iostream>
#include "cutlass/cutlass.h"
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
#include "cutlass/conv/device/implicit_gemm_convolution.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#define CUTLASS_CHECK(status) \
{ \
cutlass::Status error = status; \
if (error != cutlass::Status::kSuccess) { \
std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ \
<< std::endl; \
exit(EXIT_FAILURE); \
} \
}
{{OperatorDef}}
using ImplicitGemm = cutlass::conv::device::ImplicitGemmConvolution<{{OperatorName}}>;
struct Options {
cutlass::Tensor4DCoord input_size;
cutlass::Tensor4DCoord filter_size;
cutlass::Tensor4DCoord padding;
cutlass::MatrixCoord conv_stride;
cutlass::MatrixCoord dilation;
void parse(int argc, char const **args) {
cutlass::CommandLine cmd(argc, args);
cmd.get_cmd_line_argument("n", input_size.n());
cmd.get_cmd_line_argument("h", input_size.h());
cmd.get_cmd_line_argument("w", input_size.w());
cmd.get_cmd_line_argument("c", input_size.c());
cmd.get_cmd_line_argument("k", filter_size.n());
cmd.get_cmd_line_argument("r", filter_size.h());
cmd.get_cmd_line_argument("s", filter_size.w());
int pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w;
cmd.get_cmd_line_argument("pad_h", pad_h);
cmd.get_cmd_line_argument("pad_w", pad_w);
cmd.get_cmd_line_argument("stride_h", stride_h);
cmd.get_cmd_line_argument("stride_w", stride_w);
cmd.get_cmd_line_argument("dilation_h", dilation_h);
cmd.get_cmd_line_argument("dilation_w", dilation_w);
filter_size.c() = input_size.c();
padding = {pad_h, pad_h, pad_w, pad_w};
conv_stride = {stride_h, stride_w};
dilation = {dilation_h, dilation_w};
}
cutlass::Tensor4DCoord output_size() const {
auto dilated_h = (filter_size.h() - 1) * dilation.row() + 1;
auto dilated_w = (filter_size.w() - 1) * dilation.column() + 1;
auto h = (input_size.h() + padding.n() + padding.h() - dilated_h) / conv_stride.row() + 1;
auto w = (input_size.w() + padding.w() + padding.c() - dilated_w) / conv_stride.column() + 1;
return cutlass::Tensor4DCoord(input_size.n(), h, w, filter_size.n());
}
};
double profile_convolution(Options const &options) {
using ElementOutput = typename ImplicitGemm::ElementC;
using ElementInputA = typename ImplicitGemm::ElementA;
using ElementInputB = typename ImplicitGemm::ElementB;
auto oshape = options.output_size();
cutlass::HostTensor<ElementInputA, typename ImplicitGemm::LayoutA> tensor_a(options.input_size);
cutlass::HostTensor<ElementInputB, typename ImplicitGemm::LayoutB> tensor_b(options.filter_size);
cutlass::HostTensor<ElementOutput, typename ImplicitGemm::LayoutC> tensor_c(oshape);
cutlass::HostTensor<ElementOutput, typename ImplicitGemm::LayoutC> tensor_ref_c(oshape);
cutlass::conv::Conv2dProblemSize problem_size(
options.input_size,
options.filter_size,
options.padding,
options.conv_stride,
options.dilation,
options.output_size(),
cutlass::conv::Mode::kCrossCorrelation,
1
);
using ElementComputeEpilogue = typename ImplicitGemm::ElementCompute;
typename ImplicitGemm::Arguments arguments{
problem_size,
tensor_a.device_ref(),
tensor_b.device_ref(),
tensor_c.device_ref(),
tensor_c.device_ref(),
{ElementComputeEpilogue(1), ElementComputeEpilogue(0)},
};
ImplicitGemm implicit_gemm_op;
size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
auto status = implicit_gemm_op.can_implement(arguments);
CUTLASS_CHECK(status);
status = implicit_gemm_op.initialize(arguments, workspace.get());
CUTLASS_CHECK(status);
status = implicit_gemm_op();
CUTLASS_CHECK(status);
cudaEvent_t events[2];
for (auto & event : events) {
cudaEventCreate(&event);
}
cudaEventRecord(events[0]);
for (int iteration = 0; iteration < 100; ++iteration) {
auto status = implicit_gemm_op();
CUTLASS_CHECK(status);
}
cudaEventRecord(events[1]);
cudaEventSynchronize(events[1]);
float runtime_ms = 0;
cudaEventElapsedTime(&runtime_ms, events[0], events[1]);
for (auto event : events) {
(void)cudaEventDestroy(event);
}
return double(runtime_ms) / 100.0;
}
int main(int argc, char const **args) {
Options options;
options.parse(argc, args);
std::cout << profile_convolution(options) << std::endl;
return 0;
}
"""
)

def emit(self, op_def, op_name):
src = self.template.render(OperatorDef=op_def, OperatorName=op_name)
return src
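The emitted source is a standalone profiler executable. A rough usage sketch follows; the op_def and op_name values are placeholders (real ones come from EmitConv2dInstance and Conv2dOperation.procedural_name()), and the binary name is illustrative.

emitter = Conv2dProfilerEmitter()
src = emitter.emit(
    op_def="/* CUTLASS Conv2dFprop operator definition goes here */",
    op_name="placeholder_conv2d_fprop_kernel",
)
# After compiling `src` with nvcc against the CUTLASS headers, the binary is
# invoked with the workload flags parsed by Options::parse above, e.g.
#   ./profiler --n=8 --h=56 --w=56 --c=64 --k=64 --r=3 --s=3 \
#              --pad_h=1 --pad_w=1 --stride_h=1 --stride_w=1 \
#              --dilation_h=1 --dilation_w=1
# and prints the average runtime in milliseconds over 100 iterations.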
81 changes: 68 additions & 13 deletions python/tvm/contrib/cutlass/gen_conv2d.py
@@ -16,8 +16,14 @@
# under the License.
# pylint: disable=invalid-name
"""Conv2d kernel generator and profiler for CUTLASS."""
import re
from .conv2d_operation import Conv2dOperation, EmitConv2dInstance
from .gen_gemm import CutlassGemmProfiler
from .conv2d_profiler import Conv2dProfilerEmitter
from .gen_tensor_op import (
ProfilerEngine,
GENERATOR_FUNC_TABLE,
)
from .library import (
EpilogueFunctor,
SwizzlingFunctor,
@@ -39,6 +45,7 @@ def create_conv2d_operator(
ret = []

kernel_emitter = EmitConv2dInstance()
profiler_emitter = Conv2dProfilerEmitter()

element_a, element_b, element_c, element_epilogue = data_type
iterator_algorithms = [IteratorAlgorithm.Optimized]
@@ -75,6 +82,7 @@
# TODO(masahi): Add profiler source here
op_entry["opdef"] = kernel_emitter.emit(op)
op_entry["op"] = op
op_entry["src"] = profiler_emitter.emit(op_entry["opdef"], op.procedural_name())
op_entry["name"] = op.procedural_name()
op_entry["runtime"] = 9999999

@@ -113,6 +121,9 @@ class CutlassConv2DProfiler:
def __init__(self, sm, cutlass_path, binary_path):
self.gemm_profiler = CutlassGemmProfiler(sm, cutlass_path, binary_path)
self.sm = sm
assert sm in GENERATOR_FUNC_TABLE, "sm%d not supported yet." % sm
self.engine = ProfilerEngine(sm, cutlass_path, binary_path)
self.cache = {}

def get_default(self, out_dtype):
gemm_profile_result = self.gemm_profiler.get_default(out_dtype)
@@ -121,27 +132,71 @@ def get_default(self, out_dtype):
data_type = gemm_profile_result["data_type"]
return create_conv2d_operator([tile_description], data_type, [alignment])[0]

def check_align(self, op_name, C, K):
"""Filter out kernels that cannot be supported."""
aligns = re.findall(r"align[1|2|4|8]", op_name)
assert len(aligns) == 1
align = int(aligns[0][-1])
return all([dim % align == 0 for dim in [C, K]])

def profile(
self, d_shape, w_shape, out_shape, out_dtype, profile_all=True, use_multiprocessing=False
self,
d_shape,
w_shape,
out_shape,
padding,
stride,
dilation,
out_dtype,
profile_all=True,
use_multiprocessing=False,
):
"""Profile and select the best kernel from candidate kernels.
If profile_all is False, return immediately after the first applicable kernel is found.
If use_multiprocessing is True, compile all profiler executables in parallel.
"""
B, _, _, IC = d_shape
N, H, W, IC = d_shape
OC, R, S, _ = w_shape
_, P, Q, _ = out_shape
workload = (
N,
H,
W,
IC,
OC,
R,
S,
padding[0],
padding[1],
stride[0],
stride[1],
dilation[0],
dilation[1],
)

M = B * P * Q
N = OC
K = R * S * IC
if workload in self.cache:
return self.cache[workload]

gemm_profile_result = self.gemm_profiler.profile(
M, N, K, out_dtype, profile_all=profile_all, use_multiprocessing=use_multiprocessing
)
ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, op_creator=create_conv2d_operator)
ops = list(filter(lambda op: self.check_align(op["name"], IC, OC), ops))

tile_description = gemm_profile_result["tile_description"]
alignment = gemm_profile_result["alignment"]
data_type = gemm_profile_result["data_type"]
for op in ops:
op["runtime"] = -1

return create_conv2d_operator([tile_description], data_type, [alignment])[0]
if profile_all:
self.engine.compile_all(ops, use_multiprocessing)

args = (
"--n=%d --h=%d --w=%d --c=%d --k=%d --r=%d --s=%d --pad_h=%d --pad_w=%d "
"--stride_h=%d --stride_w=%d --dilation_h=%d --dilation_w=%d"
) % workload

for op in ops:
out = self.engine.evaluate(op, args.split(" "))
op["runtime"] = out
if out > 0 and profile_all is False:
break

valid_ops = filter(lambda op: op["runtime"] > 0, ops)
output = sorted(valid_ops, key=lambda i: i["runtime"])
self.cache[workload] = output[0]
return output[0]
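A usage sketch of the reworked profiler; the paths, shapes, and the kernel name passed to check_align are illustrative only.

profiler = CutlassConv2DProfiler(sm=80, cutlass_path="/path/to/cutlass", binary_path="./tmp")

# check_align drops kernels whose vector width does not divide the channel
# dimensions, e.g. an align8 kernel is rejected when IC = 3.
assert not profiler.check_align("some_conv2d_kernel_align8", C=3, K=64)

best = profiler.profile(
    d_shape=(8, 56, 56, 64),   # NHWC input
    w_shape=(64, 3, 3, 64),    # OHWI weight
    out_shape=(8, 56, 56, 64),
    padding=(1, 1),
    stride=(1, 1),
    dilation=(1, 1),
    out_dtype="float16",
    profile_all=True,
    use_multiprocessing=True,
)
print(best["name"], best["runtime"])  # fastest kernel and its measured time
# A second profile() call with the same workload tuple returns the cached
# entry without recompiling or rerunning any profiler binaries.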
9 changes: 1 addition & 8 deletions python/tvm/contrib/cutlass/gen_gemm.py
@@ -22,8 +22,7 @@
from .gemm_profiler import GemmProfilerEmitter
from .gen_tensor_op import (
ProfilerEngine,
generate_sm75_tensor_op_1688,
generate_sm80_tensor_op_16816,
GENERATOR_FUNC_TABLE,
)
from .library import (
EpilogueFunctor,
@@ -132,12 +131,6 @@ def create_gemm_operator(
return ret


GENERATOR_FUNC_TABLE = {
75: generate_sm75_tensor_op_1688,
80: generate_sm80_tensor_op_16816,
}


# TODO(masahi): A sensible way to pick reasonable default kernels
DEFAULT_KERNELS = {
75: {
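With the table moved, both the gemm and conv2d generators look kernels up from the same place. A small sketch; the sm value and dtype are illustrative, and the import paths follow the module locations shown in this change.

from tvm.contrib.cutlass.gen_tensor_op import GENERATOR_FUNC_TABLE
from tvm.contrib.cutlass.gen_conv2d import create_conv2d_operator

# Same call shape as used inside CutlassConv2DProfiler.profile above.
ops = GENERATOR_FUNC_TABLE[80]("float16", op_creator=create_conv2d_operator)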
14 changes: 8 additions & 6 deletions python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -159,6 +159,12 @@ def get_tile_descriptions(math_inst):
return sm75_kernels + sm80_kernels


GENERATOR_FUNC_TABLE = {
75: generate_sm75_tensor_op_1688,
80: generate_sm80_tensor_op_16816,
}


class ProfilerEngine:
"""Compile and run a given profiler executable."""

@@ -204,12 +210,8 @@ def evaluate(self, op, args):
if not os.path.exists(opath):
self._compile(op)
cmd = [opath]
if args is not None:
cmd.append(str(args[0]))
cmd.append(str(args[1]))
cmd.append(str(args[2]))
if len(args) > 3:
cmd.append(str(args[3]))
for arg in args:
cmd.append(str(arg))
try:
sp = subprocess.run(cmd, capture_output=True, check=True)
rt = float(sp.stdout)
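This generalization is what lets the conv2d profiler reuse the engine. A hedged sketch of how it is driven; the constructor arguments and flag values are illustrative, and `op` stands for an op_entry dict as built by create_conv2d_operator (with "name" and "src" fields).

engine = ProfilerEngine(80, "/path/to/cutlass", "./tmp")
args = (
    "--n=8 --h=56 --w=56 --c=64 --k=64 --r=3 --s=3 "
    "--pad_h=1 --pad_w=1 --stride_h=1 --stride_w=1 --dilation_h=1 --dilation_w=1"
)
# evaluate() now forwards every flag to the profiler binary instead of only
# the first three or four gemm-style arguments; a negative return value means
# the kernel failed to run.
runtime_ms = engine.evaluate(op, args.split(" "))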
5 changes: 4 additions & 1 deletion python/tvm/relay/op/strategy/cuda.py
@@ -324,7 +324,10 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
plevel=25,
)

elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
elif (
is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups)
and "cudnn" not in target.libs
):
if layout == "NCHW":
assert kernel_layout == "OIHW"
strategy.add_implementation(
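A small illustration of the condition being gated. The assumption here is about the surrounding strategy code: when cuDNN is in the target libs, a depthwise conv2d no longer takes the TVM depthwise-schedule branch and instead falls through to the cuDNN-backed implementation; the target string is illustrative.

import tvm

target = tvm.target.Target("cuda -libs=cudnn")
# "cudnn" in target.libs evaluates to True, so the depthwise branch above is
# skipped, e.g. for NHWC depthwise conv2d.
print("cudnn" in target.libs)  # True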