[pt2] Add logger logging for remote fx graph cache get + put (pytorch#138164)

Summary:
X-link: pytorch/benchmark#2512

Pull Request resolved: pytorch#138164

Capture the timing of the remote fx graph cache get and put operations and add those timings to the logger logging.
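
The timings are recorded in seconds and converted to integer milliseconds by the new `to_int_ms` helper before they are attached to the metrics record. A minimal sketch of that conversion (the helper matches the one added in torch/_dynamo/utils.py below; the sample value is hypothetical):

```python
from typing import Optional

def to_int_ms(v: Optional[float]) -> Optional[int]:
    # None propagates, so "the remote cache was never used" stays
    # distinguishable from "the operation took under a millisecond" (0).
    return None if v is None else int(v * 1000)

print(to_int_ms(0.05652284622192383))  # -> 56
print(to_int_ms(None))                 # -> None
```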

Test Plan:
1) Landed D64483593 and waited for logger actualization.
2) Ran test script on devserver: `buck2 run mode/opt scripts/slarsen/torch_compile_model:run`
3) Queried dynamo_compile/sandbox:
```
(pytorch-3.10_4) devvm2296:~/local/pytorch-3.10_4  $ scuba -e="select time,co_filename,remote_fx_graph_cache_get_time_s,remote_fx_graph_cache_put_time_s from \`dynamo_compile/sandbox\` where remote_fx_graph_cache_put_time_s is not null"
+------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------+----------------------------------+
|    time    |                                                                                    co_filename                                                                                    | remote_fx_graph_cache_get_time_s | remote_fx_graph_cache_put_time_s |
+------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------+----------------------------------+
| 1729136266 | null                                                                                                                                                                              |              0.05652284622192383 |               0.9691152572631836 |
| 1729136263 | /data/users/slarsen/fbsource/buck-out/v2/gen/fbcode/289bb46b326874c6/scripts/slarsen/torch_compile_model/__run__/run-inplace#link-tree/scripts/slarsen/torch_compile_model/run.py |               0.8298435211181641 |              0.18642282485961914 |
+------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+----------------------------------+----------------------------------+
```

Reviewed By: ezyang, oulgen

Differential Revision: D64484025
masnesral authored and facebook-github-bot committed Oct 18, 2024
1 parent c88b77a commit 17950b3
Showing 2 changed files with 29 additions and 0 deletions.
11 changes: 11 additions & 0 deletions torch/_dynamo/convert_frame.py

```diff
@@ -118,6 +118,7 @@
     record_compilation_metrics,
     reset_graph_break_dup_checker,
     setup_compile_debug,
+    to_int_ms,
     troubleshooting_url,
     write_record_to_file,
 )
@@ -1053,6 +1054,12 @@ def format_guard_failures() -> str:
                 "auto_functionalize",
                 {"missed_reinplacing_bytes": possibly_missed_reinplacing_bytes},
             )
+            remote_fx_graph_cache_get_time = frame_phase_timing[frame_key].get(
+                "remote_fx_graph_cache_get", None
+            )
+            remote_fx_graph_cache_put_time = frame_phase_timing[frame_key].get(
+                "remote_fx_graph_cache_put", None
+            )
         else:
             guard_count = None
             shape_env_guard_count = None
@@ -1070,6 +1077,8 @@ def format_guard_failures() -> str:
             dynamo_time_before_restart = time.time() - start_time
             possibly_missed_reinplacing_opportunities = None
             remote_cache_time_saved = None
+            remote_fx_graph_cache_get_time = None
+            remote_fx_graph_cache_put_time = None

         structured_logging_overhead_s = (
             torch._logging.get_structured_logging_overhead()
@@ -1121,6 +1130,8 @@ def handle_sets(d: Dict[str, Any]) -> Dict[str, Any]:
             config.inline_inbuilt_nn_modules,
             config.specialize_float,
             json.dumps(config_dict),
+            to_int_ms(remote_fx_graph_cache_get_time),
+            to_int_ms(remote_fx_graph_cache_put_time),
         )
         record_compilation_metrics(metrics)
         torch._dynamo.callback_handler.run_end_callbacks()
```
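
For context, the `.get(...)` calls above assume `frame_phase_timing` maps a frame key to per-phase durations in seconds. The following is an illustrative, self-contained sketch of that accumulation pattern, not the actual bookkeeping in torch/_dynamo/utils.py:

```python
import collections
import time
from typing import DefaultDict, Dict

# frame key -> phase name -> accumulated seconds (illustrative stand-in)
frame_phase_timing: DefaultDict[str, Dict[str, float]] = collections.defaultdict(dict)

def add_phase_time(key: str, phase: str, seconds: float) -> None:
    # Accumulate, so repeated cache operations within one compile sum up.
    frame_phase_timing[key][phase] = frame_phase_timing[key].get(phase, 0.0) + seconds

t0 = time.time()
# ... a remote fx graph cache get would run here ...
add_phase_time("frame_0", "remote_fx_graph_cache_get", time.time() - t0)

# The diff reads the values back, defaulting to None when a phase never ran:
get_time = frame_phase_timing["frame_0"].get("remote_fx_graph_cache_get", None)
put_time = frame_phase_timing["frame_0"].get("remote_fx_graph_cache_put", None)  # None
```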
18 changes: 18 additions & 0 deletions torch/_dynamo/utils.py

```diff
@@ -341,10 +341,18 @@ def dynamo_timed(
                     remote_cache_time_saved = frame_phase_timing[
                         compile_id
                     ].get("remote_cache_time_saved", None)
+                    remote_fx_graph_cache_get_time = frame_phase_timing[
+                        compile_id
+                    ].get("remote_fx_graph_cache_get", None)
+                    remote_fx_graph_cache_put_time = frame_phase_timing[
+                        compile_id
+                    ].get("remote_fx_graph_cache_put", None)
                 else:
                     inductor_compile_time = None
                     code_gen_time = None
                     remote_cache_time_saved = None
+                    remote_fx_graph_cache_get_time = None
+                    remote_fx_graph_cache_put_time = None
                 structured_logging_overhead_s = (
                     torch._logging.get_structured_logging_overhead()
                 )
@@ -356,6 +364,8 @@ def dynamo_timed(
                     fail_reason,
                     remote_cache_time_saved,
                     structured_logging_overhead_s,
+                    to_int_ms(remote_fx_graph_cache_get_time),
+                    to_int_ms(remote_fx_graph_cache_put_time),
                 )
                 record_compilation_metrics(metrics)

@@ -762,6 +772,10 @@ def proxy_args_kwargs(args, kwargs):
     )


+def to_int_ms(v: Optional[float]) -> Optional[int]:
+    return None if v is None else int(v * 1000)
+
+
 @dataclasses.dataclass
 class CompilationMetrics:
     is_forward: bool = dataclasses.field(default=True, init=False)
@@ -801,6 +815,8 @@ class CompilationMetrics:
     config_inline_inbuilt_nn_modules: Optional[bool]
     specialize_float: Optional[bool]
     dynamo_config: Optional[str]
+    remote_fx_graph_cache_get_time_ms: Optional[int]
+    remote_fx_graph_cache_put_time_ms: Optional[int]


 @dataclasses.dataclass
@@ -813,6 +829,8 @@ class BwdCompilationMetrics:
     fail_reason: Optional[str]
     remote_cache_time_saved_s: Optional[float]
     structured_logging_overhead_s: Optional[float]
+    remote_fx_graph_cache_get_time_ms: Optional[int]
+    remote_fx_graph_cache_put_time_ms: Optional[int]


 DEFAULT_COMPILATION_METRICS_LIMIT = 64
```
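
Taken together: the new `*_time_ms` fields are `Optional[int]`, so `None` means the remote cache was never touched for that compile, while `0` means the operation took under a millisecond. A trimmed-down sketch of how a record would be populated (stand-in dataclass, not the full CompilationMetrics; sample values taken from the test plan above):

```python
import dataclasses
from typing import Optional

def to_int_ms(v: Optional[float]) -> Optional[int]:
    return None if v is None else int(v * 1000)

@dataclasses.dataclass
class MetricsSketch:
    # Stand-ins for the two fields added to CompilationMetrics above.
    remote_fx_graph_cache_get_time_ms: Optional[int]
    remote_fx_graph_cache_put_time_ms: Optional[int]

# Timings come out of frame_phase_timing in seconds (or None) and are
# converted once, when the record is constructed:
metrics = MetricsSketch(
    remote_fx_graph_cache_get_time_ms=to_int_ms(0.8298435211181641),  # -> 829
    remote_fx_graph_cache_put_time_ms=to_int_ms(None),                # -> None
)
print(metrics)
```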
