79 changes: 67 additions & 12 deletions superbench/benchmarks/micro_benchmarks/cublaslt_function.py
@@ -56,6 +56,20 @@ def add_parser_arguments(self):
required=False,
help='Number of steps to measure for autotune.',
)
self._parser.add_argument(
'--enable_ncu_profiling',
action='store_true',
required=False,
help='Enable ncu profiling for each run.',
)
self._parser.add_argument(
'--profiling_metrics',
type=str,
nargs='+',
default=None,
required=False,
help='List of ncu profiling metrics; all ncu metrics are supported.',
)

def _preprocess(self):
"""Preprocess/preparation operations before the benchmarking.
@@ -75,16 +89,17 @@ def _preprocess(self):
f' -a -W {self._args.num_warmup_autotune}'
f' -I {self._args.num_steps_autotune}'
) if self._args.enable_autotune else ''

self._commands.append(
f'{self.__bin_path} -m {_m} -n {_n} -k {_k} -b {_b} '
f'-w {self._args.num_warmup} -i {self._args.num_steps} -t {_in_type}'
command = f'{self.__bin_path} -m {_m} -n {_n} -k {_k} -b {_b} ' + \
f'-w {self._args.num_warmup} -i {self._args.num_steps} -t {_in_type}' + \
f'{(" " + autotune_args) if autotune_args else ""}'
)
if self._args.enable_ncu_profiling:
skip_num = self._args.num_warmup - 1 if self._args.num_warmup > 1 else 0
command = f'ncu --set full --launch-skip {skip_num} --launch-count 1 --csv ' + command
self._commands.append(command)

return True
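
With --enable_ncu_profiling set, the benchmark command built above is simply prefixed with an ncu invocation. A minimal sketch of the resulting command string; the binary name, shape, warmup, and step values below are illustrative assumptions, only the flag layout mirrors the diff:

```python
# Sketch of the command produced by the new _preprocess path.
# bin_path and the numeric values are illustrative assumptions.
bin_path = 'cublaslt_gemm'
m, n, k, b, in_type = 2208, 2048, 5608, 0, 'fp8e4m3'
num_warmup, num_steps = 20, 50

command = (
    f'{bin_path} -m {m} -n {n} -k {k} -b {b} '
    f'-w {num_warmup} -i {num_steps} -t {in_type}'
)

# With profiling enabled, skip all but the last warmup launch and profile a
# single kernel launch, emitting CSV for later parsing.
skip_num = num_warmup - 1 if num_warmup > 1 else 0
command = f'ncu --set full --launch-skip {skip_num} --launch-count 1 --csv ' + command

print(command)
# ncu --set full --launch-skip 19 --launch-count 1 --csv cublaslt_gemm -m 2208 -n 2048 -k 5608 -b 0 -w 20 -i 50 -t fp8e4m3
```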

def _process_raw_result(self, cmd_idx, raw_output):
def _process_raw_result(self, cmd_idx, raw_output): # noqa: C901
"""Function to parse raw results and save the summarized results.

self._result.add_raw_data() and self._result.add_result() need to be called to save the results.
@@ -99,12 +114,52 @@ def _process_raw_result(self, cmd_idx, raw_output):
self._result.add_raw_data(f'raw_output_{cmd_idx}', raw_output, self._args.log_raw_data)

try:
fields = raw_output.strip().split()
if len(fields) != 6 or not all(x.isdigit() for x in fields[:4]):
raise ValueError('Invalid result.')
self._result.add_result(
f'{self._commands[cmd_idx].split()[-1]}_{fields[3]}_{"_".join(fields[:3])}_flops', float(fields[-1])
)
if not self._args.enable_ncu_profiling:
fields = raw_output.strip().split()
if len(fields) != 6 or not all(x.isdigit() for x in fields[:4]):
raise ValueError('Invalid result.')
self._result.add_result(
f'{self._commands[cmd_idx].split()[-1]}_{fields[3]}_{"_".join(fields[:3])}_flops',
float(fields[-1])
)
else:
lines = raw_output.strip().split('\n')
# Find the ncu CSV header line (it starts with "ID","Process ID" and contains "Metric Name")
start_idx = next(i for i, line in enumerate(lines) if 'Metric Name' in line)
if start_idx == 0 or start_idx == len(lines) - 1:
raise ValueError('Invalid result.')
result_lines = lines[0:start_idx - 1]
result = False
size = ''
for line in result_lines:
fields = line.strip().split()
if len(fields) == 6 and all(x.isdigit() for x in fields[:4]):
result = True
size = f'{fields[3]}_{"_".join(fields[:3])}'
self._result.add_result(
f'{self._commands[cmd_idx].split()[-1]}_{fields[3]}_{"_".join(fields[:3])}_flops',
float(fields[-1])
)
if not result:
raise ValueError('Invalid result.')
metric_name_index = lines[start_idx].strip().split(',').index('"Metric Name"')
metric_value_index = lines[start_idx].strip().split(',').index('"Metric Value"')
if metric_name_index < 0 or metric_value_index < 0:
raise ValueError('Can not find Metric Name and Value.')
for line in lines[start_idx + 1:]:
fields = line.strip().split('","')
metric_name = fields[metric_name_index].strip('"').replace(' ', '_')
if len(fields) < 15:
continue
if not self._args.profiling_metrics or metric_name in self._args.profiling_metrics:
value = fields[metric_value_index].strip(',').strip('"')
try:
float_value = float(value)
self._result.add_result(
f'{self._commands[cmd_idx].split()[-1]}_{size}_{metric_name}', float_value
)
except ValueError:
pass
except BaseException as e:
self._result.set_return_code(ReturnCode.MICROBENCHMARK_RESULT_PARSING_FAILURE)
logger.error(
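
For context on the parsing above: with profiling enabled, the raw output contains the usual GEMM result line(s) followed by an ncu CSV section whose header contains "Metric Name" and "Metric Value". A minimal standalone sketch of how a metric value and result key are derived from such a row; the header and data row below are hypothetical illustrations of an ncu --csv layout, not actual benchmark output:

```python
# Hypothetical ncu --csv header and data row (illustrative only); the
# name/value extraction mirrors the parsing logic in the diff.
header = (
    '"ID","Process ID","Process Name","Host Name","Kernel Name","Context","Stream",'
    '"Block Size","Grid Size","Device","CC","Section Name","Metric Name","Metric Unit","Metric Value"'
)
row = (
    '"0","1234","cublaslt_gemm","host","gemm_kernel","1","7",'
    '"128, 1, 1","16, 16, 1","0","9.0","GPU Speed Of Light Throughput","DRAM Throughput","%","0.74"'
)

columns = header.strip().split(',')
fields = row.strip().split('","')

metric_name_index = columns.index('"Metric Name"')      # 12
metric_value_index = columns.index('"Metric Value"')    # 14

metric_name = fields[metric_name_index].strip('"').replace(' ', '_')  # 'DRAM_Throughput'
value = float(fields[metric_value_index].strip(',').strip('"'))       # 0.74

# The result key combines the input type (last token of the command), the
# batch/m/n/k parsed from the GEMM result line, and the metric name.
size = '0_2208_2048_5608'  # f'{batch}_{m}_{n}_{k}' from the GEMM result line
print(f'fp8e4m3_{size}_{metric_name}', value)  # fp8e4m3_0_2208_2048_5608_DRAM_Throughput 0.74
```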
20 changes: 18 additions & 2 deletions tests/benchmarks/micro_benchmarks/test_cublaslt_function.py
@@ -6,6 +6,7 @@
import unittest
from types import GeneratorType, SimpleNamespace

from tests.helper import decorator
from tests.helper.testcase import BenchmarkTestCase
from superbench.benchmarks import BenchmarkRegistry, BenchmarkType, ReturnCode, Platform
from superbench.benchmarks.result import BenchmarkResult
@@ -83,11 +84,14 @@ def cmd(t, b, m, n, k):
for _m in [32, 128]:
self.assertIn(cmd(_t, _b, _m, 128, 128), benchmark._commands)

def test_cublaslt_gemm_result_parsing(self):
@decorator.load_data('tests/data/cublaslt_ncu.log')
def test_cublaslt_gemm_result_parsing(self, raw_output):
"""Test cublaslt-gemm benchmark result parsing."""
benchmark = self.get_benchmark()
self.assertTrue(benchmark._preprocess())
benchmark._args = SimpleNamespace(shapes=['16,16,16', '32,64,128'], in_types=['fp8e4m3'], log_raw_data=False)
benchmark._args = SimpleNamespace(
shapes=['16,16,16', '32,64,128'], in_types=['fp8e4m3'], log_raw_data=False, enable_ncu_profiling=False
)
benchmark._result = BenchmarkResult(self.benchmark_name, BenchmarkType.MICRO, ReturnCode.SUCCESS, run_count=1)

# Positive case - valid raw output
@@ -101,3 +105,15 @@ def test_cublaslt_gemm_result_parsing(self):

# Negative case - invalid raw output
self.assertFalse(benchmark._process_raw_result(1, 'cuBLAS API failed'))

# Positive case - valid ncu raw output
benchmark._args = SimpleNamespace(
shapes=['2208,2048,5608'],
in_types=['fp8e4m3'],
log_raw_data=False,
enable_ncu_profiling=True,
profiling_metrics=['DRAM_Throughput'],
)
benchmark._result = BenchmarkResult(self.benchmark_name, BenchmarkType.MICRO, ReturnCode.SUCCESS, run_count=1)
self.assertTrue(benchmark._process_raw_result(1, raw_output))
self.assertEqual(0.74, benchmark.result['fp8e4m3_0_2208_2048_5608_DRAM_Throughput'][0])