Commit 54de05d

[builder] polish builder with better base class (#2216)
* [builder] polish builder
* remove print
1 parent 3b1b91e commit 54de05d
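
With this change the Builder base class owns the shared JIT-compile (load) and setup.py (builder) logic, and each op builder only has to provide four hooks: sources_files, include_dirs, cxx_flags and nvcc_flags. A minimal sketch of what a subclass looks like under the new contract; the op name, source file and import path below are illustrative assumptions, not taken from this commit:

    import os

    from colossalai.kernel.op_builder.builder import Builder   # import path assumed from the file layout below


    class MyOpBuilder(Builder):
        NAME = 'my_op'                    # hypothetical op name
        BASE_DIR = "cuda_native/csrc"     # assumed to mirror the builders changed in this commit

        def __init__(self):
            self.name = MyOpBuilder.NAME
            super().__init__()

        # the four hooks every subclass must override
        def sources_files(self):
            return [self.colossalai_src_path(os.path.join(MyOpBuilder.BASE_DIR, "my_op.cpp"))]   # hypothetical source

        def include_dirs(self):
            return [self.colossalai_src_path(os.path.join(MyOpBuilder.BASE_DIR, "includes")),
                    self.get_cuda_home_include()]

        def cxx_flags(self):
            return ['-O3']

        def nvcc_flags(self):
            return ['-O3', '--use_fast_math']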

File tree: 4 files changed, +89 -111 lines

4 files changed

+89
-111
lines changed

colossalai/kernel/op_builder/builder.py

Lines changed: 38 additions & 5 deletions

@@ -30,13 +30,31 @@ def colossalai_src_path(self, code_path):
         else:
             return os.path.join(Path(__file__).parent.parent.absolute(), code_path)
 
-    def get_cuda_include(self):
+    def get_cuda_home_include(self):
+        """
+        return include path inside the cuda home.
+        """
         from torch.utils.cpp_extension import CUDA_HOME
         if CUDA_HOME is None:
             raise RuntimeError("CUDA_HOME is None, please set CUDA_HOME to compile C++/CUDA kernels in ColossalAI.")
         cuda_include = os.path.join(CUDA_HOME, "include")
         return cuda_include
 
+    # functions must be overrided begin
+    def sources_files(self):
+        raise NotImplementedError
+
+    def include_dirs(self):
+        raise NotImplementedError
+
+    def cxx_flags(self):
+        raise NotImplementedError
+
+    def nvcc_flags(self):
+        raise NotImplementedError
+    # functions must be overrided over
+
     def strip_empty_entries(self, args):
         '''
         Drop any empty strings from the list of compile and link flags
@@ -57,10 +75,10 @@ def load(self, verbose=True):
         start_build = time.time()
 
         op_module = load(name=self.name,
-                         sources=self.strip_empty_entries(self.sources),
-                         extra_include_paths=self.strip_empty_entries(self.extra_include_paths),
-                         extra_cflags=self.extra_cxx_flags,
-                         extra_cuda_cflags=self.extra_cuda_flags,
+                         sources=self.strip_empty_entries(self.sources_files()),
+                         extra_include_paths=self.strip_empty_entries(self.include_dirs()),
+                         extra_cflags=self.cxx_flags(),
+                         extra_cuda_cflags=self.nvcc_flags(),
                          extra_ldflags=[],
                          verbose=verbose)
 
@@ -69,3 +87,18 @@ def load(self, verbose=True):
             print(f"Time to load {self.name} op: {build_duration} seconds")
 
         return op_module
+
+    def builder(self, name) -> 'CUDAExtension':
+        """
+        get a CUDAExtension instance used for setup.py
+        """
+        from torch.utils.cpp_extension import CUDAExtension
+
+        return CUDAExtension(
+            name=name,
+            sources=[os.path.join('colossalai/kernel/cuda_native/csrc', path) for path in self.sources_files()],
+            include_dirs=self.include_dirs(),
+            extra_compile_args={
+                'cxx': self.cxx_flags(),
+                'nvcc': self.nvcc_flags()
+            })
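
The new builder() helper lets setup.py ask each op builder for a ready-made CUDAExtension instead of duplicating flag lists. A hedged sketch of how that could look, using the two builder classes visible in this diff; the extension names and the setup() wiring are assumptions, not shown in the commit:

    from setuptools import setup
    from torch.utils.cpp_extension import BuildExtension

    from colossalai.kernel.op_builder.cpu_adam import CPUAdamBuilder        # import paths assumed from the repo layout
    from colossalai.kernel.op_builder.fused_optim import FusedOptimBuilder

    # each call returns a torch.utils.cpp_extension.CUDAExtension assembled from the
    # subclass's sources_files()/include_dirs()/cxx_flags()/nvcc_flags() hooks
    ext_modules = [
        CPUAdamBuilder().builder('colossalai._C.cpu_adam'),           # hypothetical extension names
        FusedOptimBuilder().builder('colossalai._C.fused_optim'),
    ]

    setup(name='example', ext_modules=ext_modules, cmdclass={'build_ext': BuildExtension})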

colossalai/kernel/op_builder/cpu_adam.py

Lines changed: 18 additions & 55 deletions

@@ -12,68 +12,31 @@ def __init__(self):
         self.name = CPUAdamBuilder.NAME
         super().__init__()
 
-        self.sources = [self.colossalai_src_path(path) for path in self.sources_files()]
-        self.extra_include_paths = [self.colossalai_src_path(path) for path in self.include_paths()]
-        self.extra_cxx_flags = ['-std=c++14', '-lcudart', '-lcublas', '-g', '-Wno-reorder', '-fopenmp', '-march=native']
-        self.extra_cuda_flags = [
-            '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__',
-            '-U__CUDA_NO_HALF2_OPERATORS__', '-DTHRUST_IGNORE_CUB_VERSION_CHECK'
-        ]
         self.version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5']
 
+    # necessary 4 functions
     def sources_files(self):
-        return [
+        ret = [
             os.path.join(CPUAdamBuilder.BASE_DIR, "csrc/cpu_adam.cpp"),
         ]
+        return [self.colossalai_src_path(path) for path in ret]
 
-    def include_paths(self):
-        return [os.path.join(CPUAdamBuilder.BASE_DIR, "includes"), self.get_cuda_include()]
-
-    def strip_empty_entries(self, args):
-        '''
-        Drop any empty strings from the list of compile and link flags
-        '''
-        return [x for x in args if len(x) > 0]
-
-    def builder(self, name) -> 'CUDAExtension':
-        """
-        get a CUDAExtension instance used for setup.py
-        """
-        from torch.utils.cpp_extension import CUDAExtension
-
-        return CUDAExtension(
-            name=name,
-            sources=[os.path.join('colossalai/kernel/cuda_native/csrc', path) for path in self.sources],
-            include_dirs=self.extra_include_paths,
-            extra_compile_args={
-                'cxx': ['-O3'] + self.version_dependent_macros + self.extra_cxx_flags,
-                'nvcc':
-                    append_nvcc_threads(['-O3', '--use_fast_math'] + self.version_dependent_macros +
-                                        self.extra_cuda_flags)
-            })
-
-    def load(self, verbose=True):
-        """
-        load and compile cpu_adam lib at runtime
-
-        Args:
-            verbose (bool, optional): show detailed info. Defaults to True.
-        """
-        import time
+    def include_dirs(self):
+        return [
+            self.colossalai_src_path(os.path.join(CPUAdamBuilder.BASE_DIR, "includes")),
+            self.get_cuda_home_include()
+        ]
 
-        from torch.utils.cpp_extension import load
-        start_build = time.time()
+    def cxx_flags(self):
+        extra_cxx_flags = ['-std=c++14', '-lcudart', '-lcublas', '-g', '-Wno-reorder', '-fopenmp', '-march=native']
+        return ['-O3'] + self.version_dependent_macros + extra_cxx_flags
 
-        op_module = load(name=self.name,
-                         sources=self.strip_empty_entries(self.sources),
-                         extra_include_paths=self.strip_empty_entries(self.extra_include_paths),
-                         extra_cflags=self.extra_cxx_flags,
-                         extra_cuda_cflags=self.extra_cuda_flags,
-                         extra_ldflags=[],
-                         verbose=verbose)
+    def nvcc_flags(self):
+        extra_cuda_flags = [
+            '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__',
+            '-U__CUDA_NO_HALF2_OPERATORS__', '-DTHRUST_IGNORE_CUB_VERSION_CHECK'
+        ]
 
-        build_duration = time.time() - start_build
-        if verbose:
-            print(f"Time to load {self.name} op: {build_duration} seconds")
+        return append_nvcc_threads(['-O3', '--use_fast_math'] + self.version_dependent_macros + extra_cuda_flags)
 
-        return op_module
+    # necessary 4 functions
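
At runtime the same class can JIT-compile the kernel through the inherited load(); nothing op-specific is needed beyond the four hooks above. A minimal usage sketch (the import path is assumed from the repo layout, and the attributes of the returned extension module are not shown in this diff):

    from colossalai.kernel.op_builder.cpu_adam import CPUAdamBuilder

    # compiles csrc/cpu_adam.cpp via torch.utils.cpp_extension.load on first use
    # and prints the build time when verbose=True
    cpu_adam_module = CPUAdamBuilder().load(verbose=True)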

colossalai/kernel/op_builder/fused_optim.py

Lines changed: 15 additions & 25 deletions

@@ -1,7 +1,4 @@
 import os
-import re
-
-import torch
 
 from .builder import Builder, get_cuda_cc_flag
 
@@ -13,33 +10,26 @@ class FusedOptimBuilder(Builder):
     def __init__(self):
         self.name = FusedOptimBuilder.NAME
         super().__init__()
-
-        self.extra_cxx_flags = []
-        self.extra_cuda_flags = ['-lineinfo']
-        self.extra_cuda_flags.extend(get_cuda_cc_flag())
-
-        self.sources = [self.colossalai_src_path(path) for path in self.sources_files()]
-        self.extra_include_paths = [self.colossalai_src_path(path) for path in self.include_paths()]
         self.version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5']
 
     def sources_files(self):
-        return [
-            os.path.join(FusedOptimBuilder.BASE_DIR, fname) for fname in [
+        ret = [
+            self.colossalai_src_path(os.path.join(FusedOptimBuilder.BASE_DIR, fname)) for fname in [
                 'colossal_C_frontend.cpp', 'multi_tensor_sgd_kernel.cu', 'multi_tensor_scale_kernel.cu',
                 'multi_tensor_adam.cu', 'multi_tensor_l2norm_kernel.cu', 'multi_tensor_lamb.cu'
             ]
         ]
+        return ret
+
+    def include_dirs(self):
+        ret = [os.path.join(FusedOptimBuilder.BASE_DIR, "includes"), self.get_cuda_home_include()]
+        return [self.colossalai_src_path(path) for path in ret]
+
+    def cxx_flags(self):
+        extra_cxx_flags = []
+        return ['-O3'] + self.version_dependent_macros + extra_cxx_flags
 
-    def include_paths(self):
-        return [os.path.join(FusedOptimBuilder.BASE_DIR, "includes"), self.get_cuda_include()]
-
-    def builder(self, name):
-        from torch.utils.cpp_extension import CUDAExtension
-        return CUDAExtension(
-            name=name,
-            sources=[os.path.join('colossalai/kernel/cuda_native/csrc', path) for path in self.sources],
-            include_dirs=self.extra_include_paths,
-            extra_compile_args={
-                'cxx': ['-O3'] + self.version_dependent_macros + self.extra_cxx_flags,
-                'nvcc': ['-O3', '--use_fast_math'] + self.extra_cuda_flags
-            })
+    def nvcc_flags(self):
+        extra_cuda_flags = ['-lineinfo']
+        extra_cuda_flags.extend(get_cuda_cc_flag())
+        return ['-O3', '--use_fast_math'] + extra_cuda_flags

colossalai/kernel/op_builder/multi_head_attn.py

Lines changed: 18 additions & 26 deletions

@@ -9,41 +9,33 @@ def __init__(self):
         self.base_dir = "cuda_native/csrc"
         self.name = 'multihead_attention'
         super().__init__()
-        self.extra_cxx_flags = []
-        self.extra_cuda_flags = [
-            '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__',
-            '-U__CUDA_NO_HALF2_OPERATORS__', '-DTHRUST_IGNORE_CUB_VERSION_CHECK'
-        ]
-
-        self.extra_cuda_flags.extend(get_cuda_cc_flag())
-        self.sources = [self.colossalai_src_path(path) for path in self.sources_files()]
-        self.extra_include_paths = [self.colossalai_src_path(path) for path in self.include_paths()]
 
         self.version_dependent_macros = ['-DVERSION_GE_1_1', '-DVERSION_GE_1_3', '-DVERSION_GE_1_5']
 
+    def include_dirs(self):
+        ret = []
+        ret = [os.path.join(self.base_dir, "includes"), self.get_cuda_home_include()]
+        ret.append(os.path.join(self.base_dir, "kernels", "include"))
+        return [self.colossalai_src_path(path) for path in ret]
+
     def sources_files(self):
-        return [
+        ret = [
             os.path.join(self.base_dir, fname) for fname in [
                 'multihead_attention_1d.cpp', 'kernels/cublas_wrappers.cu', 'kernels/transform_kernels.cu',
                 'kernels/dropout_kernels.cu', 'kernels/normalize_kernels.cu', 'kernels/softmax_kernels.cu',
                 'kernels/general_kernels.cu', 'kernels/cuda_util.cu'
             ]
         ]
+        return [self.colossalai_src_path(path) for path in ret]
 
-    def include_paths(self):
-        ret = []
-        ret = [os.path.join(self.base_dir, "includes"), self.get_cuda_include()]
-        ret.append(os.path.join(self.base_dir, "kernels", "include"))
-        print("include_paths", ret)
-        return ret
+    def cxx_flags(self):
+        return ['-O3'] + self.version_dependent_macros
 
-    def builder(self, name):
-        from torch.utils.cpp_extension import CUDAExtension
-        return CUDAExtension(
-            name=name,
-            sources=[os.path.join('colossalai/kernel/cuda_native/csrc', path) for path in self.sources],
-            include_dirs=self.extra_include_paths,
-            extra_compile_args={
-                'cxx': ['-O3'] + self.version_dependent_macros,
-                'nvcc': ['-O3', '--use_fast_math'] + self.extra_cuda_flags
-            })
+    def nvcc_flags(self):
+        extra_cuda_flags = [
+            '-std=c++14', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__',
+            '-U__CUDA_NO_HALF2_OPERATORS__', '-DTHRUST_IGNORE_CUB_VERSION_CHECK'
+        ]
+        extra_cuda_flags.extend(get_cuda_cc_flag())
+        ret = ['-O3', '--use_fast_math'] + extra_cuda_flags
+        return ret
