-
Notifications
You must be signed in to change notification settings - Fork 355
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Generate benchmarks automatically (#5561)
Summary: Pull Request resolved: #5561 ## Context Use the automatic test generation infrastructure to generate operator benchmarks. The overall concept is the same as the test generation; we just structure the generated code in the style of the google benchmark library instead of GTEST. ghstack-source-id: 244287193 Reviewed By: derekxu, nathanaelsee Differential Revision: D63286132 fbshipit-source-id: 25c379accf6664dfca8232db81772b638b41c758
- Loading branch information
1 parent
ca0e48c
commit 5a984cc
Showing
4 changed files
with
595 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import argparse | ||
import os | ||
|
||
from typing import Dict | ||
|
||
from executorch.backends.vulkan.test.op_tests.cases import test_suites | ||
|
||
from executorch.backends.vulkan.test.op_tests.utils.gen_benchmark_vk import ( | ||
VkBenchmarkFileGen, | ||
) | ||
from executorch.backends.vulkan.test.op_tests.utils.gen_computegraph import ( | ||
ComputeGraphGen, | ||
) | ||
from executorch.backends.vulkan.test.op_tests.utils.test_suite import TestSuite | ||
from torchgen import local | ||
|
||
from torchgen.gen import parse_native_yaml, ParsedYaml | ||
from torchgen.model import DispatchKey, NativeFunction | ||
|
||
|
||
def registry_name(f: NativeFunction) -> str: | ||
name = str(f.namespace) + "." + str(f.func.name) | ||
if len(f.func.name.overload_name) == 0: | ||
name += ".default" | ||
return name | ||
|
||
|
||
def construct_f_map(parsed_yaml: ParsedYaml) -> Dict[str, NativeFunction]: | ||
f_map: Dict[str, NativeFunction] = {} | ||
for f in parsed_yaml.native_functions: | ||
f_map[registry_name(f)] = f | ||
return f_map | ||
|
||
|
||
def process_test_suites( | ||
cpp_generator: VkBenchmarkFileGen, | ||
f_map: Dict[str, NativeFunction], | ||
test_suites: Dict[str, TestSuite], | ||
) -> None: | ||
for registry_name, op_test_suite in test_suites.items(): | ||
f = f_map[registry_name] | ||
cpp_generator.add_suite(registry_name, f, op_test_suite) | ||
|
||
|
||
@local.parametrize( | ||
use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False | ||
) | ||
def generate_cpp( | ||
native_functions_yaml_path: str, tags_path: str, output_dir: str | ||
) -> None: | ||
output_file = os.path.join(output_dir, "op_benchmarks.cpp") | ||
cpp_generator = VkBenchmarkFileGen(output_file) | ||
|
||
parsed_yaml = parse_native_yaml(native_functions_yaml_path, tags_path) | ||
f_map = construct_f_map(parsed_yaml) | ||
|
||
ComputeGraphGen.backend_key = parsed_yaml.backend_indices[DispatchKey.CPU] | ||
|
||
process_test_suites(cpp_generator, f_map, test_suites) | ||
|
||
with open(output_file, "w") as file: | ||
file.write(cpp_generator.generate_cpp()) | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument( | ||
"--aten-yaml-path", | ||
help="path to native_functions.yaml file.", | ||
) | ||
parser.add_argument( | ||
"--tags-path", | ||
help="Path to tags.yaml. Required by yaml parsing in gen_correctness_vk system.", | ||
) | ||
|
||
parser.add_argument("-o", "--output", help="Output directory", required=True) | ||
args = parser.parse_args() | ||
generate_cpp(args.aten_yaml_path, args.tags_path, args.output) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.