Skip to content

Commit

Permalink
[CUTLASS] Fix hardcoded include path and logic for profile_all = Fals…
Browse files Browse the repository at this point in the history
…e case (#9399)

* fixed hardcoded cutlass include path

* fixed profile_all = False case

* add cutlass cmake option

* check if cutlass path exists

* improve err msg when cutlass is not found
  • Loading branch information
masahi authored Oct 30, 2021
1 parent 3a889e7 commit 4087e72
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 9 deletions.
4 changes: 4 additions & 0 deletions cmake/config.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -347,3 +347,7 @@ set(USE_PAPI OFF)
# Note that cmake will use `find_package` to find GTest. Please use cmake's
# predefined variables to specify the path to the GTest package if needed.
set(USE_GTEST AUTO)

# Whether to enable CUTLASS as a BYOC backend.
# Requires USE_CUDA=ON.
set(USE_CUTLASS OFF)
2 changes: 1 addition & 1 deletion cmake/modules/contrib/CUTLASS.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# specific language governing permissions and limitations
# under the License.

if(USE_CUTLASS)
if(USE_CUDA AND USE_CUTLASS)
file(GLOB CUTLASS_RELAY_CONTRIB_SRC src/relay/backend/contrib/cutlass/*.cc)
list(APPEND COMPILER_SRCS ${CUTLASS_RELAY_CONTRIB_SRC})

Expand Down
24 changes: 19 additions & 5 deletions python/tvm/contrib/cutlass/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,24 @@
# under the License.
# pylint: disable=invalid-name
"""Driver for partitioning and building a Relay module for CUTLASS offload."""
import os
import tvm
from tvm import runtime, relay
from .gen_gemm import CutlassGemmProfiler


def _get_cutlass_path():
    """Return the root directory of the CUTLASS sources shipped under ``3rdparty``.

    The location is computed relative to this file: its symlink-resolved
    directory, four levels up, then ``3rdparty/cutlass``.

    Returns
    -------
    cutlass_path : str
        Normalized absolute path of the CUTLASS root directory.

    Raises
    ------
    AssertionError
        If the computed directory does not exist.
    """
    this_dir = os.path.dirname(os.path.realpath(__file__))
    # Collapse the ".." segments so the error message and any path derived from the
    # result stay readable. This is safe because realpath() already resolved symlinks.
    cutlass_path = os.path.normpath(
        os.path.join(this_dir, "..", "..", "..", "..", "3rdparty", "cutlass")
    )
    # Raise explicitly instead of using ``assert`` so the check is not stripped when
    # Python runs with -O. AssertionError is kept so callers see the same exception type.
    if not os.path.exists(cutlass_path):
        raise AssertionError(
            "The CUTLASS root directory not found in {}.\n"
            "Currently, using CUTLASS requires building TVM from source.".format(cutlass_path)
        )
    return cutlass_path


class GemmAnnotator(tvm.relay.ExprVisitor):
"""Annotates partitioned functions with shape and dtype information."""

Expand Down Expand Up @@ -70,7 +83,7 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, t
num_cutlass_partition : int
The number of partitioned functions created for CUTLASS.
"""
cutlass_profiler = CutlassGemmProfiler(sm, "../../../3rdparty/cutlass", tmp_dir)
cutlass_profiler = CutlassGemmProfiler(sm, _get_cutlass_path(), tmp_dir)
num_cutlass_partition = 0
for var in mod.get_global_vars():
fun_name = var.name_hint
Expand Down Expand Up @@ -152,8 +165,9 @@ def build_cutlass_kernels(lib, sm, tmp_dir="./tmp", lib_path="compile.so"):
updated_lib : runtime.Module
The updated module with compiled cutlass kernels.
"""
cutlass_path = "../../../3rdparty/cutlass/include"
cutlass_util_path = "../../../3rdparty/cutlass/tools/util/include"
cutlass_root = _get_cutlass_path()
cutlass_include = os.path.join(cutlass_root, "include")
cutlass_util_include = os.path.join(cutlass_root, "tools/util/include")

kwargs = {}
kwargs["cc"] = "nvcc"
Expand All @@ -165,8 +179,8 @@ def build_cutlass_kernels(lib, sm, tmp_dir="./tmp", lib_path="compile.so"):
"-Xcompiler=-fno-strict-aliasing",
"-O3",
"-std=c++14",
"-I" + cutlass_path,
"-I" + cutlass_util_path,
"-I" + cutlass_include,
"-I" + cutlass_util_include,
]
lib.export_library(lib_path, workspace_dir=tmp_dir, **kwargs)
return runtime.load_module(lib_path)
10 changes: 7 additions & 3 deletions python/tvm/contrib/cutlass/gen_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,9 +295,12 @@ def compile_all(self, ops, use_multiprocessing=False):
for op in ops:
self._compile(op)

def evaluate(self, op_name, args):
def evaluate(self, op, args):
"""Run the profiler executable corresponding to op_name with args."""
op_name = op["name"]
opath = os.path.join(self.binary_prefix, op_name)
if not os.path.exists(opath):
self._compile(op)
cmd = [opath]
if args is not None:
cmd.append(str(args[0]))
Expand Down Expand Up @@ -342,10 +345,11 @@ def profile(self, M, N, K, out_dtype, profile_all=True, use_multiprocessing=Fals
for op in ops:
op["runtime"] = -1

self.engine.compile_all(ops, use_multiprocessing)
if profile_all:
self.engine.compile_all(ops, use_multiprocessing)

for op in ops:
out = self.engine.evaluate(op["name"], [M, N, K])
out = self.engine.evaluate(op, [M, N, K])
op["runtime"] = out
if out > 0 and profile_all is False:
break
Expand Down

0 comments on commit 4087e72

Please sign in to comment.