From 0ece49b53e773ebc1ea71c7667abc0cbb29d91bf Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 3 Feb 2022 03:05:21 +0900 Subject: [PATCH] wip --- .../tvm/contrib/cutlass/conv2d_operation.py | 19 ++++++++++++++++++- python/tvm/contrib/cutlass/split_k.py | 18 ++++++++++++++++++ 2 files changed, 36 insertions(+), 1 deletion(-) create mode 100644 python/tvm/contrib/cutlass/split_k.py diff --git a/python/tvm/contrib/cutlass/conv2d_operation.py b/python/tvm/contrib/cutlass/conv2d_operation.py index 5318cc7d74c47..7decf473aac6c 100644 --- a/python/tvm/contrib/cutlass/conv2d_operation.py +++ b/python/tvm/contrib/cutlass/conv2d_operation.py @@ -197,9 +197,26 @@ def __init__(self): ${align_a}, ${align_b} >::Kernel; +""" + self.reduction_template = + """ +using ReductionOp = cutlass::reduction::thread::ReduceAdd< + ElementAccumulator, + typename EpilogueOutputOp::ElementAccumulator, + EpilogueOutputOp::kCount + >; + +using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK< + cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>, + EpilogueOutputOp, + ReductionOp + >; + +using ReductionDevice = cutlass::reduction::device::ReduceSplitK; +using ReductionStrideIndex = typename ReductionDevice::StrideIndex; """ - def emit(self, operation, no_beta_scaling=False, residual_block_info=False): + def emit(self, operation, no_beta_scaling=False, residual_block_info=False, split_k_slices=1): """Instantiate a Conv2d kernel from given `operation`.""" warp_shape = [ int( diff --git a/python/tvm/contrib/cutlass/split_k.py b/python/tvm/contrib/cutlass/split_k.py new file mode 100644 index 0000000000000..aa3d93d7016c7 --- /dev/null +++ b/python/tvm/contrib/cutlass/split_k.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name +"""Generate code for parallel Split-k mode"""