Skip to content

Commit

Permalink
Reduce versbosity in manifest.py (NVIDIA#845)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yinghai Lu authored Mar 7, 2023
1 parent a31b43b commit a68e2f9
Showing 1 changed file with 26 additions and 23 deletions.
49 changes: 26 additions & 23 deletions tools/library/scripts/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,13 @@
from rank_2k_operation import *
from trmm_operation import *
from symm_operation import *
from conv2d_operation import *
from conv3d_operation import *
from conv2d_operation import *
from conv3d_operation import *
import logging

###################################################################################################
_LOGGER = logging.getLogger(__name__)


class EmitOperationKindLibrary:
def __init__(self, generated_path, kind, args):
Expand All @@ -26,8 +29,8 @@ def __init__(self, generated_path, kind, args):
self.args = args
self.emitters = {
OperationKind.Gemm: EmitGemmConfigurationLibrary
, OperationKind.Conv2d: EmitConv2dConfigurationLibrary
, OperationKind.Conv3d: EmitConv3dConfigurationLibrary
, OperationKind.Conv2d: EmitConv2dConfigurationLibrary
, OperationKind.Conv3d: EmitConv3dConfigurationLibrary
, OperationKind.RankK: EmitRankKConfigurationLibrary
, OperationKind.Rank2K: EmitRank2KConfigurationLibrary
, OperationKind.Trmm: EmitTrmmConfigurationLibrary
Expand Down Expand Up @@ -92,7 +95,7 @@ def emit(self, configuration_name, operations):
with self.emitters[self.kind](self.operation_path, configuration_name) as configuration_emitter:
for operation in operations:
configuration_emitter.emit(operation)

self.source_files.append(configuration_emitter.configuration_path)

self.configurations.append(configuration_name)
Expand Down Expand Up @@ -162,7 +165,7 @@ def emit(self, operation_name):
self.fn_calls.append(SubstituteTemplate(
"\t\t\tinitialize_all_${operation_kind}_operations(manifest);",
{'operation_kind': operation_name}))



#
Expand Down Expand Up @@ -209,21 +212,21 @@ def __init__(self, args = None):
architectures = [x if x != '90a' else '90' for x in architectures]

self.compute_capabilities = [int(x) for x in architectures]

if args.filter_by_cc in ['false', 'False', '0']:
self.filter_by_cc = False

if args.operations == 'all':
self.operations_enabled = []
else:
operations_list = [
OperationKind.Gemm
, OperationKind.Conv2d
, OperationKind.Conv3d
, OperationKind.Conv2d
, OperationKind.Conv3d
, OperationKind.RankK
, OperationKind.Trmm
, OperationKind.Symm
]
]
self.operations_enabled = [x for x in operations_list if OperationKindNames[x] in args.operations.split(',')]

if args.kernels == 'all':
Expand All @@ -248,7 +251,7 @@ def get_kernel_filters (self, kernelListFile):
if os.path.isfile(kernelListFile):
with open(kernelListFile, 'r') as fileReader:
lines = [line.rstrip() for line in fileReader if not line.startswith("#")]

lines = [re.compile(line) for line in lines if line]
return lines
else:
Expand All @@ -260,10 +263,10 @@ def filter_out_kernels(self, kernel_name, kernel_filter_list):
for kernel_filter_re in kernel_filter_list:
if kernel_filter_re.search(kernel_name) is not None:
return True

return False


#
def _filter_string_matches(self, filter_string, haystack):
''' Returns true if all substrings appear in the haystack in order'''
Expand Down Expand Up @@ -316,7 +319,7 @@ def filter(self, operation):
if self._filter_string_matches(name_substr, name):
enabled = False
break

if len(self.kernel_filter_list) > 0:
enabled = False
if self.filter_out_kernels(operation.procedural_name(), self.kernel_filter_list):
Expand All @@ -328,14 +331,14 @@ def filter(self, operation):

#
def append(self, operation):
'''
'''
Inserts the operation.
operation_kind -> configuration_name -> []
'''

if self.filter(operation):

self.selected_kernels.append(operation.procedural_name())

self.operations_by_name[operation.procedural_name()] = operation
Expand All @@ -352,17 +355,17 @@ def append(self, operation):
self.operations[operation.operation_kind][configuration_name].append(operation)
self.operation_count += 1
else:
print("Culled {} from manifest".format(operation.procedural_name()))
_LOGGER.debug("Culled {} from manifest".format(operation.procedural_name()))
#

#
def emit(self, target = GeneratorTarget.Library):

operation_emitters = {
GeneratorTarget.Library: EmitOperationKindLibrary
GeneratorTarget.Library: EmitOperationKindLibrary
}
interface_emitters = {
GeneratorTarget.Library: EmitInterfaceLibrary
GeneratorTarget.Library: EmitInterfaceLibrary
}

generated_path = os.path.join(self.curr_build_dir, 'generated')
Expand Down Expand Up @@ -421,7 +424,7 @@ def for_ampere(name):

def for_turing(name):
return ("1688" in name and "tf32" not in name) or \
"8816" in name
"8816" in name

def for_volta(name):
return "884" in name
Expand Down Expand Up @@ -451,8 +454,8 @@ def get_src_archs_str_given_requested_cuda_archs(archs, source_file):
elif for_volta(source_file):
archs_str = get_src_archs_str_given_requested_cuda_archs({70, 72}, source_file)
else:
raise RuntimeError("Per file archs are not set {}, as there is no rule specified for this file pattern".format(source_file))
raise RuntimeError("Per file archs are not set {}, as there is no rule specified for this file pattern".format(source_file))

manifest_file.write("cutlass_apply_cuda_gencode_flags({} SM_ARCHS {})\n".format(str(source_file.replace('\\', '/')), archs_str))
#

Expand Down

0 comments on commit a68e2f9

Please sign in to comment.