Skip to content

Commit

Permalink
Generate benchmarks automatically (#5561)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #5561

## Context

Use the automatic test generation infrastructure to generate operator benchmarks. The overall concept is the same as the test generation; we just structure the generated code in the style of the Google Benchmark library instead of GoogleTest (GTest).
ghstack-source-id: 244287193

Reviewed By: derekxu, nathanaelsee

Differential Revision: D63286132

fbshipit-source-id: 25c379accf6664dfca8232db81772b638b41c758
  • Loading branch information
SS-JIA authored and facebook-github-bot committed Sep 24, 2024
1 parent ca0e48c commit 5a984cc
Show file tree
Hide file tree
Showing 4 changed files with 595 additions and 24 deletions.
84 changes: 84 additions & 0 deletions backends/vulkan/test/op_tests/generate_op_benchmarks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import os

from typing import Dict

from executorch.backends.vulkan.test.op_tests.cases import test_suites

from executorch.backends.vulkan.test.op_tests.utils.gen_benchmark_vk import (
VkBenchmarkFileGen,
)
from executorch.backends.vulkan.test.op_tests.utils.gen_computegraph import (
ComputeGraphGen,
)
from executorch.backends.vulkan.test.op_tests.utils.test_suite import TestSuite
from torchgen import local

from torchgen.gen import parse_native_yaml, ParsedYaml
from torchgen.model import DispatchKey, NativeFunction


def registry_name(f: NativeFunction) -> str:
    """Return the registry key for a native function.

    The key is "<namespace>.<op name>"; when the operator has no overload
    name, ".default" is appended to match the registry's naming scheme
    (str(f.func.name) already embeds the overload when one exists).
    """
    base = f"{f.namespace}.{f.func.name}"
    return base if len(f.func.name.overload_name) > 0 else base + ".default"


def construct_f_map(parsed_yaml: ParsedYaml) -> Dict[str, NativeFunction]:
    """Build a lookup table mapping each registry key to its NativeFunction."""
    return {registry_name(fn): fn for fn in parsed_yaml.native_functions}


def process_test_suites(
    cpp_generator: VkBenchmarkFileGen,
    f_map: Dict[str, NativeFunction],
    test_suites: Dict[str, TestSuite],
) -> None:
    """Register every test suite with the benchmark file generator.

    Args:
        cpp_generator: Generator that accumulates suites and later emits C++.
        f_map: Registry key -> NativeFunction, as built by construct_f_map.
        test_suites: Registry key -> suite of benchmark cases for that op.

    Raises:
        KeyError: If a suite's registry key has no matching NativeFunction.
    """
    # Loop variable renamed from `registry_name` — the original shadowed the
    # module-level registry_name() helper defined above.
    for op_key, op_test_suite in test_suites.items():
        f = f_map[op_key]
        cpp_generator.add_suite(op_key, f, op_test_suite)


@local.parametrize(
    use_const_ref_for_mutable_tensors=False, use_ilistref_for_tensor_lists=False
)
def generate_cpp(
    native_functions_yaml_path: str, tags_path: str, output_dir: str
) -> None:
    """Generate op_benchmarks.cpp from the ATen operator YAML definitions.

    Parses native_functions.yaml / tags.yaml, registers every benchmark test
    suite with the generator, and writes the resulting C++ source into
    output_dir/op_benchmarks.cpp.
    """
    out_path = os.path.join(output_dir, "op_benchmarks.cpp")
    file_gen = VkBenchmarkFileGen(out_path)

    parsed = parse_native_yaml(native_functions_yaml_path, tags_path)
    fn_by_key = construct_f_map(parsed)

    # The compute-graph codegen resolves kernels against the CPU backend index.
    ComputeGraphGen.backend_key = parsed.backend_indices[DispatchKey.CPU]

    process_test_suites(file_gen, fn_by_key, test_suites)

    with open(out_path, "w") as sink:
        sink.write(file_gen.generate_cpp())


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate op_benchmarks.cpp from ATen YAML op definitions."
    )
    # Both YAML paths are mandatory for generation; marking them required=True
    # gives a clear argparse error instead of a None-path crash deep inside
    # parse_native_yaml. (All build-system callers already pass both flags.)
    parser.add_argument(
        "--aten-yaml-path",
        required=True,
        help="path to native_functions.yaml file.",
    )
    parser.add_argument(
        "--tags-path",
        required=True,
        help="Path to tags.yaml. Required by yaml parsing in gen_correctness_vk system.",
    )

    parser.add_argument("-o", "--output", help="Output directory", required=True)
    args = parser.parse_args()
    generate_cpp(args.aten_yaml_path, args.tags_path, args.output)
53 changes: 53 additions & 0 deletions backends/vulkan/test/op_tests/targets.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,19 @@ def define_common_targets(is_fbcode = False):
external_deps = ["torchgen"],
)

# Python library bundling the benchmark generator entry point, the operator
# test-case definitions (cases.py), and the shared codegen utilities.
runtime.python_library(
name = "generate_op_benchmarks_lib",
srcs = native.glob(["utils/*.py"]) + [
"generate_op_benchmarks.py",
"cases.py",
],
base_module = "executorch.backends.vulkan.test.op_tests",
deps = [
"fbsource//third-party/pypi/expecttest:expecttest",
],
external_deps = ["torchgen"],
)

runtime.python_binary(
name = "generate_op_correctness_tests",
main_module = "executorch.backends.vulkan.test.op_tests.generate_op_correctness_tests",
Expand All @@ -28,6 +41,14 @@ def define_common_targets(is_fbcode = False):
],
)

# Executable wrapper around generate_op_benchmarks.py; invoked by the
# benchmarks genrule below to emit op_benchmarks.cpp.
runtime.python_binary(
name = "generate_op_benchmarks",
main_module = "executorch.backends.vulkan.test.op_tests.generate_op_benchmarks",
deps = [
":generate_op_benchmarks_lib",
],
)

aten_src_path = runtime.external_dep_location("aten-src-path")
genrule_cmd = [
"$(exe :generate_op_correctness_tests)",
Expand All @@ -45,6 +66,22 @@ def define_common_targets(is_fbcode = False):
default_outs = ["."],
)

# Command line for the codegen step: points the generator at the ATen
# tags.yaml / native_functions.yaml and writes its output into $OUT.
benchmarks_genrule_cmd = [
"$(exe :generate_op_benchmarks)",
"--tags-path $(location {})/aten/src/ATen/native/tags.yaml".format(aten_src_path),
"--aten-yaml-path $(location {})/aten/src/ATen/native/native_functions.yaml".format(aten_src_path),
"-o $OUT",
]

# Runs the generator at build time to produce op_benchmarks.cpp, which the
# benchmark binary below compiles.
runtime.genrule(
name = "generated_op_benchmarks_cpp",
outs = {
"op_benchmarks.cpp": ["op_benchmarks.cpp"],
},
cmd = " ".join(benchmarks_genrule_cmd),
default_outs = ["."],
)

pt_operator_library(
name = "all_aten_ops",
check_decl = False,
Expand Down Expand Up @@ -76,6 +113,22 @@ def define_common_targets(is_fbcode = False):
],
)

# Benchmark executable: compiles the generated op_benchmarks.cpp against the
# Google Benchmark library, the Vulkan graph runtime, and the ATen op library.
runtime.cxx_binary(
name = "compute_graph_op_benchmarks_bin",
srcs = [
":generated_op_benchmarks_cpp[op_benchmarks.cpp]",
],
compiler_flags = [
# Generated code may declare values it never reads; don't fail the build.
"-Wno-unused-variable",
],
define_static_target = False,
deps = [
"//third-party/benchmark:benchmark",
"//executorch/backends/vulkan:vulkan_graph_runtime",
":all_aten_ops_lib",
],
)

runtime.cxx_test(
name = "compute_graph_op_tests",
srcs = [
Expand Down
Loading

0 comments on commit 5a984cc

Please sign in to comment.