diff --git a/torchbenchmark/operators_collection/__init__.py b/torchbenchmark/operators_collection/__init__.py
new file mode 100644
index 000000000..9a60deb76
--- /dev/null
+++ b/torchbenchmark/operators_collection/__init__.py
@@ -0,0 +1,74 @@
+import importlib
+import pathlib
+from typing import List
+
+OP_COLLECTION_PATH = "operators_collection"
+
+
+def list_operator_collections() -> List[str]:
+    """
+    List the available operator collections.
+
+    A collection is any direct subdirectory of this package that contains an
+    "__init__.py" file; the directory name is the collection name.
+
+    Returns:
+        List[str]: Sorted names of the available operator collections.
+    """
+    package_dir = pathlib.Path(__file__).parent
+    # only load the directories that contain a "__init__.py" file
+    return sorted(
+        child.name
+        for child in package_dir.iterdir()
+        if child.is_dir() and child.joinpath("__init__.py").exists()
+    )
+
+
+def list_operators_by_collection(op_collection: str = "default") -> List[str]:
+    """
+    List the operators from the specified operator collections.
+
+    Args:
+        op_collection (str): Name of the operator collection to list operators
+            from, or a comma-separated list of names. Whitespace around names
+            and empty entries are ignored. The special value "all" retrieves
+            operators from every available collection. Defaults to "default".
+
+    Returns:
+        List[str]: Operator names from the specified collection(s), in
+            first-seen order and without duplicates.
+
+    Raises:
+        ModuleNotFoundError: If the specified collection module is not found.
+        AttributeError: If the specified collection module does not have a
+            'get_operators' function.
+    """
+
+    def _list_all_operators(collection_name: str):
+        module_name = f".{collection_name}"
+        try:
+            module = importlib.import_module(module_name, package=__name__)
+        except ModuleNotFoundError as e:
+            # Chain the cause: the missing module may be a dependency of the
+            # collection rather than the collection itself.
+            raise ModuleNotFoundError(f"Module '{module_name}' not found") from e
+        if not hasattr(module, "get_operators"):
+            raise AttributeError(
+                f"Module '{module_name}' does not have a 'get_operators' function"
+            )
+        return module.get_operators()
+
+    if op_collection == "all":
+        collection_names = list_operator_collections()
+    else:
+        collection_names = [
+            name.strip() for name in op_collection.split(",") if name.strip()
+        ]
+
+    # Collections can overlap (the "default" collection is derived from "all"),
+    # so drop duplicates while keeping first-seen order; otherwise the same
+    # operator would be returned more than once.
+    unique_operators = {}
+    for collection_name in collection_names:
+        unique_operators.update(dict.fromkeys(_list_all_operators(collection_name)))
+    return list(unique_operators)
diff --git a/torchbenchmark/operators_collection/all/__init__.py b/torchbenchmark/operators_collection/all/__init__.py
new file mode 100644
index 000000000..6adb7ef8b
--- /dev/null
+++ b/torchbenchmark/operators_collection/all/__init__.py
@@ -0,0 +1,5 @@
+from torchbenchmark.operators import list_operators
+
+
+def get_operators():
+    return list_operators()
diff --git a/torchbenchmark/operators_collection/default/__init__.py b/torchbenchmark/operators_collection/default/__init__.py
new file mode 100644
index 000000000..ed898630c
--- /dev/null
+++ b/torchbenchmark/operators_collection/default/__init__.py
@@ -0,0 +1,24 @@
+from torchbenchmark.operators_collection.all import get_operators as get_all_operators
+from torchbenchmark.operators_collection.liger import (
+    get_operators as get_liger_operators,
+)
+
+
+def get_operators():
+    """
+    Retrieve the list of operators for the default collection.
+
+    The default collection is every operator of the 'all' collection that is
+    not part of the 'liger' collection; the order of the 'all' collection is
+    preserved.
+
+    In the future, if we add more operator collections, we will need to update
+    this function to also exclude the operators of those collections.
+
+    Returns:
+        List[str]: A list of operator names for the default collection.
+    """
+    all_operators = get_all_operators()
+    # A set gives constant-time membership tests in the filter below.
+    liger_operators = set(get_liger_operators())
+    return [item for item in all_operators if item not in liger_operators]
diff --git a/torchbenchmark/operators_collection/liger/__init__.py b/torchbenchmark/operators_collection/liger/__init__.py
new file mode 100644
index 000000000..4d955aa3e
--- /dev/null
+++ b/torchbenchmark/operators_collection/liger/__init__.py
@@ -0,0 +1,5 @@
+liger_operators = ["FusedLinearCrossEntropy"]
+
+
+def get_operators():
+    return liger_operators
diff --git a/userbenchmark/triton/run.py b/userbenchmark/triton/run.py
index 32dd41dd3..222eef75c 100644
--- a/userbenchmark/triton/run.py
+++ b/userbenchmark/triton/run.py
@@ -7,6 +7,7 @@
 from torch import version as torch_version
 from torchbenchmark.operator_loader import load_opbench_by_name_from_loader
 from torchbenchmark.operators import load_opbench_by_name
+from torchbenchmark.operators_collection import list_operators_by_collection
 
 from torchbenchmark.util.triton_op import (
     BenchmarkOperatorResult,
@@ -36,6 +37,13 @@ def get_parser(args=None):
         required=False,
         help="Operators to benchmark. Split with comma if multiple.",
     )
+    parser.add_argument(
+        "--op-collection",
+        default="default",
+        type=str,
+        help="Operator collections to benchmark. Split with comma."
+        " It is conflict with --op. Choices: [default, liger, all]",
+    )
     parser.add_argument(
         "--mode",
         choices=["fwd", "bwd", "fwd_bwd", "fwd_no_grad"],
@@ -158,8 +166,10 @@ def get_parser(args=None):
     args, extra_args = parser.parse_known_args(args)
     if args.op and args.ci:
         parser.error("cannot specify operator when in CI mode")
-    elif not args.op and not args.ci:
-        parser.error("must specify operator when not in CI mode")
+    if not args.op and not args.op_collection:
+        print(
+            "Neither operator nor operator collection is specified. Running all operators in the default collection."
+        )
     return parser
 
 
@@ -221,7 +231,8 @@ def run(args: List[str] = []):
     if args.op:
         ops = args.op.split(",")
     else:
-        ops = []
+        ops = list_operators_by_collection(args.op_collection)
+
     with gpu_lockdown(args.gpu_lockdown):
         for op in ops:
             args.op = op