diff --git a/benchmarks/dynamo/common.py b/benchmarks/dynamo/common.py
index d7d2584167d109..45a9f51ad85fe0 100644
--- a/benchmarks/dynamo/common.py
+++ b/benchmarks/dynamo/common.py
@@ -1518,9 +1518,15 @@ def run_one_model(
             )
             print(status)
         if self.args.timing:
-            from torch._dynamo.utils import print_time_report
+            from torch._dynamo.utils import op_count, print_time_report
+            from torch.utils._stats import simple_call_counter

             print_time_report()
+            stats = f"STATS: call_* op count: {op_count} | "
+            stats = stats + " | ".join(
+                f"{key}:{value}" for key, value in simple_call_counter.items()
+            )
+            print(stats)

         end_calls_captured = torch._dynamo.utils.counters["stats"]["calls_captured"]
         end_unique_graphs = torch._dynamo.utils.counters["stats"]["unique_graphs"]
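The STATS line printed above is what `parse_logs.py` scrapes below (note the `" | "` separator after the op count, which the parser's three-way split relies on). As a quick illustration of how the two derived columns fall out of that line, here is a minimal standalone sketch; the sample numbers are invented and `parse_stats` is a hypothetical helper, not part of this patch:

```python
# Hypothetical sketch of the STATS parsing added below; the numbers are made up.
sample = (
    "STATS: call_* op count: 970 | FakeTensor.__torch_dispatch__:35285"
    " | ProxyTorchDispatchMode.__torch_dispatch__:13339"
)

def parse_stats(line, frame_time_s):
    # Split the three "key:value" fields emitted by run_one_model.
    ops, fake, proxy = line.split("|")
    tot_ops = int(ops.split("call_* op count:")[1])
    fm = int(fake.split("FakeTensor.__torch_dispatch__:")[1])
    pm = int(proxy.split("ProxyTorchDispatchMode.__torch_dispatch__:")[1])
    return {
        "time_per_op_ms": frame_time_s / tot_ops * 1000,  # frame compile time per FX op
        "dispatches_per_op": (fm + pm) / tot_ops,         # fake + proxy dispatches per FX op
    }

print(parse_stats(sample, frame_time_s=12.5))
# -> {'time_per_op_ms': 12.88..., 'dispatches_per_op': 50.12...}
```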
diff --git a/benchmarks/dynamo/parse_logs.py b/benchmarks/dynamo/parse_logs.py
index ab9b7589d52526..a8f882bd204071 100644
--- a/benchmarks/dynamo/parse_logs.py
+++ b/benchmarks/dynamo/parse_logs.py
@@ -46,7 +46,24 @@ def chunker(seq, size):
 i = 0

 out = csv.writer(sys.stdout, dialect="excel")
-out.writerow(["", hash, "", "", "", "", gist_url])
+out.writerow(
+    [
+        "",
+        hash,
+        "",
+        "",
+        "",
+        "",
+        gist_url,
+        "frame_time",
+        "backend_time",
+        "total_ops",
+        "fake_tensor_dispatch_calls",
+        "proxy_torch_dispatch_calls",
+        "time_per_op",
+        "dispatches_per_op",
+    ]
+)

 # Sometimes backtraces will be in third party code, which results
 # in very long file names.  Delete the absolute path in this case.
@@ -130,6 +147,29 @@ def normalize_file(f):
         if len(split_str) == 2:
             backend_time = float(split_str[1])
             frame_time = float(split_str[0].split("entire_frame_compile:")[1])
+
+    tot_ops = None
+    fm_dispatches = None
+    pm_dispatches = None
+    if "STATS:" in log:
+        result = re.search("STATS:(.*)\n", log).group(1)
+        # call_* op count: 970 | FakeTensor.__torch_dispatch__:35285 | ProxyTorchDispatchMode.__torch_dispatch__:13339
+        split_all = result.split("|")
+
+        if len(split_all) == 3:
+            tot_ops = int(split_all[0].split("call_* op count:")[1])
+            fm_dispatches = int(split_all[1].split("FakeTensor.__torch_dispatch__:")[1])
+            pm_dispatches = int(
+                split_all[2].split("ProxyTorchDispatchMode.__torch_dispatch__:")[1]
+            )
+    time_per_op = None
+    if frame_time is not None and tot_ops is not None:
+        time_per_op = frame_time / tot_ops * 1000  # ms
+
+    dispatches_per_op = None
+    if fm_dispatches is not None and pm_dispatches is not None and tot_ops is not None:
+        dispatches_per_op = (fm_dispatches + pm_dispatches) / tot_ops
+
     # If the context string is too long, don't put it in the CSV.
     # This is a hack to try to make it more likely that Google Sheets will
     # offer to split columns
@@ -143,7 +183,22 @@ def normalize_file(f):
         context = ""

     out.writerow(
-        [bench, name, "", r, component, context, explain, frame_time, backend_time]
+        [
+            bench,
+            name,
+            "",
+            r,
+            component,
+            context,
+            explain,
+            frame_time,
+            backend_time,
+            tot_ops,
+            fm_dispatches,
+            pm_dispatches,
+            time_per_op,
+            dispatches_per_op,
+        ]
     )
     i += 1
diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py
index 00b4a7ee821963..9d9f9bb8470d26 100644
--- a/torch/_dynamo/output_graph.py
+++ b/torch/_dynamo/output_graph.py
@@ -627,6 +627,11 @@ def compile_and_call_fx_graph(self, tx, rv, root):

     @dynamo_timed(phase_name="backend_compile")
     def call_user_compiler(self, gm: fx.GraphModule) -> CompiledFn:
+        tot = 0
+        for node in gm.graph.nodes:
+            if node.op in ("call_function", "call_method", "call_module"):
+                tot += 1
+        torch._dynamo.utils.increment_op_count(tot)
         try:
             name = (
                 self.compiler_fn.__name__
diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py
index f9acade618af64..ce0893c4db3c09 100644
--- a/torch/_dynamo/utils.py
+++ b/torch/_dynamo/utils.py
@@ -50,7 +50,6 @@

 # profiling compilation time
 compilation_metrics = collections.OrderedDict()
-
 timer_counter = itertools.count()


@@ -103,6 +102,14 @@ def reset_frame_count():
     curr_frame = 0


+op_count = 0
+
+
+def increment_op_count(cnt):
+    global op_count
+    op_count += cnt
+
+
 # Print a report of time spent so far
 # Ex:
 # TIMING:
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index 608c84802c73d1..79666e935a8fab 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -20,6 +20,7 @@

 from torch.utils._python_dispatch import TorchDispatchMode
 from torch.utils._pytree import PyTree, tree_flatten, tree_map, tree_map_only
+from torch.utils._stats import count
 from torch.utils.weak import WeakIdRef

 pytree = torch.utils._pytree
@@ -623,6 +624,7 @@ def __repr__(self):
         return f"FakeTensor({self_repr}, {self.fake_device})"

     @classmethod
+    @count
     def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
         # need to handle here to avoid infinite recursion
         # see [in_kernel_invocation]
diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py
index e3f6903b3ecda2..690f9a41e6b1ae 100644
--- a/torch/fx/experimental/proxy_tensor.py
+++ b/torch/fx/experimental/proxy_tensor.py
@@ -18,6 +18,7 @@
 from dataclasses import dataclass
 import weakref
 import operator
+from torch.utils._stats import count

 from torch.utils._python_dispatch import TorchDispatchMode, _pop_mode_temporarily, _get_current_dispatch_mode
 from torch._subclasses import FakeTensor
@@ -477,6 +478,7 @@ def __init__(self, tracer, tracing_mode):
         self.trace_state = {}
         self._managers = []

+    @count
     def __torch_dispatch__(self, func, types, args=(), kwargs=None):
         with self.sym_mode.enable(False):
             return self.inner_torch_dispatch(func, types, args, kwargs)
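Both dispatch hot paths are wrapped by the same `count` decorator, defined in the new `torch/utils/_stats.py` below. Note the ordering in `fake_tensor.py`: `@count` sits beneath `@classmethod`, so it wraps the raw function (and `functools.wraps` can read its `__qualname__`) before `classmethod` turns it into a descriptor. A toy sketch of the same pattern; `ToyMode` and `dispatch` are invented names, not part of this patch:

```python
import collections
import functools

simple_call_counter = collections.OrderedDict()

def count(fn):
    # Same shape as the decorator added in torch/utils/_stats.py below.
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        simple_call_counter[fn.__qualname__] = simple_call_counter.get(fn.__qualname__, 0) + 1
        return fn(*args, **kwargs)
    return wrapper

class ToyMode:  # hypothetical stand-in for a dispatch mode
    @classmethod
    @count  # applied first: fn here is the plain function, qualname "ToyMode.dispatch"
    def dispatch(cls, op):
        return op

ToyMode.dispatch("aten.add")
ToyMode.dispatch("aten.mul")
print(dict(simple_call_counter))  # {'ToyMode.dispatch': 2}
```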
diff --git a/torch/utils/_stats.py b/torch/utils/_stats.py
new file mode 100644
index 00000000000000..1e218d9766bb87
--- /dev/null
+++ b/torch/utils/_stats.py
@@ -0,0 +1,16 @@
+# NOTE! PLEASE KEEP THIS FILE *FREE* OF TORCH DEPS! IT SHOULD BE IMPORTABLE ANYWHERE.
+# IF YOU FEEL AN OVERWHELMING URGE TO ADD A TORCH DEP, MAKE A TRAMPOLINE FILE A LA torch._dynamo.utils
+# AND SCRUB AWAY TORCH NOTIONS THERE.
+import collections
+import functools
+
+simple_call_counter = collections.OrderedDict()
+
+def count(fn):
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        if fn.__qualname__ not in simple_call_counter:
+            simple_call_counter[fn.__qualname__] = 0
+        simple_call_counter[fn.__qualname__] = simple_call_counter[fn.__qualname__] + 1
+        return fn(*args, **kwargs)
+    return wrapper
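With all of the pieces in place, the counters can be inspected directly after a compile. A hypothetical smoke test, assuming this patch is applied to a working PyTorch checkout; the function and the backend choice are arbitrary:

```python
# Hypothetical smoke test; assumes this patch is applied.
import torch
import torch._dynamo.utils as dynamo_utils
from torch.utils._stats import simple_call_counter

def f(x):
    return torch.relu(x) + x.sin()

compiled = torch.compile(f, backend="aot_eager")
compiled(torch.randn(8))

# increment_op_count ran inside call_user_compiler; the @count decorators
# bumped the per-qualname counters during fake/proxy tracing.
print("call_* op count:", dynamo_utils.op_count)
print(dict(simple_call_counter))
```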