diff --git a/test/profiler/test_profiler.py b/test/profiler/test_profiler.py
index cf5e32bfed578..729c380afdea3 100644
--- a/test/profiler/test_profiler.py
+++ b/test/profiler/test_profiler.py
@@ -1259,7 +1259,6 @@ def test_nested_tensor_with_shapes(self):
             self.assertTrue(len(e.input_shapes[0]) > 0)
 
 
-
 def find_node_with_name(nodes, name):
     for node in nodes:
         if node.name() == name:
@@ -1268,6 +1267,17 @@ def find_node_with_name(nodes, name):
         if result is not None:
             return result
 
+
+class SimpleNet(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.fc1 = nn.Linear(10, 5)
+        self.fc2 = nn.Linear(5, 2)
+
+    def forward(self, x):
+        return self.fc2(self.fc1(x))
+
+
 class TestTorchTidyProfiler(TestCase):
 
     def test_pointers_and_ids(self):
@@ -1336,6 +1346,7 @@ def get_fields(op_name, index):
         self.assertEqual(c_id, c_id_new)
         self.assertEqual(d_id, c_id_new)
 
+
     def test_extra_fields(self):
         with profile(with_stack=True, profile_memory=True) as p:
             _ = torch.ones((1,))
@@ -1477,18 +1488,10 @@ def flat_out_extrafields(nodes, out=None):
                 flat_out_extrafields(node.children, out)
             return out
 
-        class simpleNet(nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.fc1 = nn.Linear(10, 5)
-                self.fc2 = nn.Linear(5, 2)
-
-            def forward(self, x):
-                return self.fc2(self.fc1(x))
-
         inputs = torch.rand(10)
         with torch.profiler.profile(with_stack=True, profile_memory=True) as p:
-            net = simpleNet()
+            net = SimpleNet()
             out = net(inputs)
 
         modules = flat_out_extrafields(p.profiler.kineto_results.experimental_event_tree())
@@ -1499,6 +1502,32 @@ def forward(self, x):
         expected += [(name, val.storage().data_ptr()) for name, val in net.fc2._parameters.items()]
         self.assertEqual(expected, params, f"{expected} vs. {params}")
 
+    def test_optimizer(self):
+
+        def flat_out_extrafields(nodes, out=None):
+            if out is None:
+                out = []
+            for node in nodes:
+                if isinstance(node.extra_fields, _ExtraFields_PyCall) and node.extra_fields.opt:
+                    out.append(node.extra_fields.opt.self)
+                flat_out_extrafields(node.children, out)
+            return out
+
+        inputs = torch.rand(10)
+        with torch.profiler.profile(with_stack=True, profile_memory=True) as p:
+            net = SimpleNet()
+            opt = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
+
+            opt.zero_grad()
+            out = net(inputs)
+            loss = torch.nn.functional.cross_entropy(out, torch.rand(2))
+            loss.backward()
+            opt.step()
+
+        opts = flat_out_extrafields(p.profiler.kineto_results.experimental_event_tree())
+        self.assertEqual(len(opts), 1, f"Expected 1 optimizer, got {len(opts)}")
+        self.assertEqual(id(opt), opts[0], f"Optimizer addr ({id(opt)}) vs. Profiled optimizer addr ({opts[0]})")
+
     def test_allocations(self):
         gc.collect()
         with profile(profile_memory=True) as p:
diff --git a/torch/csrc/autograd/profiler_python.cpp b/torch/csrc/autograd/profiler_python.cpp
index debc1f6f83274..fdca317cc88c3 100644
--- a/torch/csrc/autograd/profiler_python.cpp
+++ b/torch/csrc/autograd/profiler_python.cpp
@@ -37,8 +37,8 @@ namespace torch {
 namespace profiler {
 namespace impl {
 namespace {
-enum CallType { PyCall = 0, PyModuleCall, PyCCall };
-static constexpr size_t CallTypeSize = 3;
+enum CallType { PyCall = 0, PyModuleCall, PyCCall, PyOptimizerCall };
+static constexpr size_t CallTypeSize = 4;
 using no_ephemeral_t = std::tuple<>;
 
 // ============================================================================
@@ -63,7 +63,11 @@ struct CodeLocation {
   int line_number_{0};
 };
 
-PyCodeObject* nnModuleCode() {
+template <CallType C>
+PyCodeObject* getCode();
+
+template <>
+PyCodeObject* getCode<CallType::PyModuleCall>() {
   static auto module_call_code = []() {
     pybind11::gil_scoped_acquire gil;
     auto res = py::module::import("torch.nn")
@@ -75,14 +79,21 @@ PyCodeObject* nnModuleCode() {
     return (PyCodeObject*)res;
   }();
   return module_call_code;
-}
-
-template <CallType C>
-PyCodeObject* getCode();
+};
 
 template <>
-PyCodeObject* getCode<CallType::PyModuleCall>() {
-  return nnModuleCode();
+PyCodeObject* getCode<CallType::PyOptimizerCall>() {
+  static auto optimizer_step_code = []() {
+    pybind11::gil_scoped_acquire gil;
+    auto res = py::module::import("torch.optim")
+                   .attr("Optimizer")
+                   .attr("_optimizer_step_code")
+                   .attr("__code__")
+                   .ptr();
+    TORCH_INTERNAL_ASSERT(PyCode_Check(res));
+    return (PyCodeObject*)res;
+  }();
+  return optimizer_step_code;
 };
 } // namespace
@@ -215,6 +226,25 @@ struct Config<CallType::PyCCall> {
   static constexpr EventType event_type = EventType::PyCCall;
 };
 
+template <>
+struct Config<CallType::PyOptimizerCall> {
+  using key_t = PyOptimizerSelf;
+  using cls_t = PyOptimizerCls;
+  using ephemeral_t = PyFrameObject*;
+  struct info_t {
+    cls_t cls_;
+    std::vector<void*> params_;
+    std::vector<std::pair<std::string, void*>> states_;
+  };
+  struct cache_t {
+    c10::optional<CodeLocation>
+        location_; // optim.Optimizer._optimizer_step_code;
+    ska::flat_hash_map<key_t, info_t> optimizer_data_;
+    ska::flat_hash_map<cls_t, at::StringView> cls_names_;
+  };
+  static constexpr EventType event_type = EventType::PyCall;
+};
+
 // ============================================================================
 // == Callsite & ValueCache: Storage during profiling =========================
 // ============================================================================
@@ -245,6 +275,7 @@ class Callsite {
 using PyCallKey = Config<CallType::PyCall>::key_t;
 using PyModuleCallKey = Config<CallType::PyModuleCall>::key_t;
 using PyCCallKey = Config<CallType::PyCCall>::key_t;
+using PyOptimizerCallKey = Config<CallType::PyOptimizerCall>::key_t;
 
 class ValueCache {
  public:
@@ -254,11 +285,11 @@ class ValueCache {
   template <CallType C>
   auto load(const Callsite<C>& callsite, size_t python_tid) const {
     auto caller = load<CallType::PyCall>(callsite.caller_);
-    TORCH_INTERNAL_ASSERT(!caller.second.has_value());
+    TORCH_INTERNAL_ASSERT(!caller.module_info_.has_value());
     return ExtraFields<Config<C>::event_type>{
         /*end_time_ns=*/std::numeric_limits<time_t>::min(),
         python_tid,
-        caller.first,
+        caller.frame_state_,
         load<C>(callsite.value_)};
   }
 
@@ -354,6 +385,39 @@ ExtraFields<EventType::PyCall>::args_t ValueCache::load(
           cache.cls_names_.at(cls),
           cache.modules_and_params_.at(key).second}};
 }
 
+template <>
+void ValueCache::store<CallType::PyOptimizerCall>(
+    const PyOptimizerCallKey& key,
+    Config<CallType::PyOptimizerCall>::ephemeral_t frame) {
+  auto& cache = std::get<CallType::PyOptimizerCall>(state_);
+  if (C10_UNLIKELY(
+          cache.optimizer_data_.find(key) == cache.optimizer_data_.end())) {
+    auto cls = set_class<CallType::PyOptimizerCall>(this, cache, key, frame);
+    py::list param_groups_handle =
+        py::handle((PyObject*)key).attr("param_groups");
+    std::vector<void*> params_;
+    std::vector<std::pair<std::string, void*>> states_;
+
+    cache.optimizer_data_[key] = {cls, params_, states_};
+  }
+}
+
+template <>
+ExtraFields<EventType::PyCall>::args_t ValueCache::load<
+    CallType::PyOptimizerCall>(const PyOptimizerCallKey& key) const {
+  auto& cache = std::get<CallType::PyOptimizerCall>(state_);
+  auto cls = cache.optimizer_data_.at(key).cls_;
+  auto frame_state = std::get<CallType::PyCall>(state_).at(*cache.location_);
+  return {
+      frame_state,
+      c10::nullopt,
+      OptimizerInfo{
+          key,
+          cls,
+          cache.cls_names_.at(cls),
+          cache.optimizer_data_.at(key).params_,
+          cache.optimizer_data_.at(key).states_}};
+}
 
 template <>
 void ValueCache::store<CallType::PyCCall>(
@@ -565,13 +629,16 @@ class PythonTracer final : public python_tracer::PythonTracerBase {
 
   torch::profiler::impl::RecordQueue* queue_;
   PyCodeObject* module_call_code_;
+  PyCodeObject* optimizer_hook_;
 
   std::deque<ThreadLocalResults> thread_local_results_;
   ValueCache value_cache_;
 };
 
 PythonTracer::PythonTracer(torch::profiler::impl::RecordQueue* queue)
-    : queue_(queue), module_call_code_(nnModuleCode()) {
+    : queue_(queue),
+      module_call_code_(getCode<CallType::PyModuleCall>()),
+      optimizer_hook_(getCode<CallType::PyOptimizerCall>()) {
   TORCH_CHECK(queue_ != nullptr);
 
   bool expected{false};
@@ -682,6 +749,14 @@ void PythonTracer::recordPyCall(ThreadLocalResults& tls, PyFrameObject* frame) {
       TORCH_INTERNAL_ASSERT(back != nullptr);
       return tls.intern<CallType::PyModuleCall, E>(
           frame, self.get(), back.get());
+    } else if (code.get() == optimizer_hook_) {
+      auto locals = THPObjectPtr(PyFrame_GetLocals(frame));
+      auto self = THPObjectPtr(PyDict_GetItemString(locals, "self"));
+      Py_INCREF(self.get());
+      auto back = THPFrameObjectPtr(PyFrame_GetBack(frame));
+      TORCH_INTERNAL_ASSERT(back != nullptr);
+      return tls.intern<CallType::PyOptimizerCall, E>(
+          frame, self.get(), back.get());
     } else {
       auto back = THPFrameObjectPtr(PyFrame_GetBack(frame));
       auto f_back = (back.get() != nullptr) ? back.get() : frame;
diff --git a/torch/csrc/profiler/collection.h b/torch/csrc/profiler/collection.h
index 59ba446b9e8eb..447f4a9a1cd72 100644
--- a/torch/csrc/profiler/collection.h
+++ b/torch/csrc/profiler/collection.h
@@ -219,6 +219,8 @@ using strong_t = strong::
 using PyModuleSelf = strong_t<PyObject*, struct PyModuleSelf_>;
 using PyModuleCls = strong_t<PyObject*, struct PyModuleCls_>;
 using PyMethod = strong_t</*PyMethodDef*/ void*, struct PyMethod_>;
+using PyOptimizerSelf = strong_t<PyObject*, struct PyOptSelf_>;
+using PyOptimizerCls = strong_t<PyObject*, struct PyOptimizer_>;
 
 struct NNModuleInfo {
   PyModuleSelf self_;
@@ -230,6 +232,15 @@ struct NNModuleInfo {
   size_t id_{std::numeric_limits<size_t>::max()};
 };
 
+struct OptimizerInfo {
+  PyOptimizerSelf self_;
+  PyOptimizerCls opt_;
+  at::StringView opt_name_;
+
+  std::vector<void*> params_addr_;
+  std::vector<std::pair<std::string, void*>> opt_state_;
+};
+
 struct PyExtraFieldsBase {
   PyExtraFieldsBase(time_t end_time_ns, size_t python_tid, PyFrameState caller)
       : end_time_ns_{end_time_ns}, python_tid_{python_tid}, caller_{caller} {}
@@ -244,7 +255,11 @@ struct PyExtraFieldsBase {
 
 template <>
 struct ExtraFields<EventType::PyCall> : public PyExtraFieldsBase {
-  using args_t = std::pair<PyFrameState, c10::optional<NNModuleInfo>>;
+  using args_t = struct {
+    PyFrameState frame_state_;
+    c10::optional<NNModuleInfo> module_info_;
+    c10::optional<OptimizerInfo> opt_info_;
+  };
 
   ExtraFields(
       time_t end_time_ns,
@@ -252,11 +267,13 @@ struct ExtraFields<EventType::PyCall> : public PyExtraFieldsBase {
       PyFrameState caller,
       args_t args)
       : PyExtraFieldsBase(end_time_ns, python_tid, caller),
-        callsite_{args.first},
-        module_{args.second} {}
+        callsite_{args.frame_state_},
+        module_{args.module_info_},
+        opt_{args.opt_info_} {}
 
   PyFrameState callsite_;
   c10::optional<NNModuleInfo> module_;
+  c10::optional<OptimizerInfo> opt_;
 };
 
 template <>
diff --git a/torch/csrc/profiler/python/init.cpp b/torch/csrc/profiler/python/init.cpp
index 20630600bd5e9..724098b8ebbdd 100644
--- a/torch/csrc/profiler/python/init.cpp
+++ b/torch/csrc/profiler/python/init.cpp
@@ -174,8 +174,13 @@ void initPythonBindings(PyObject* module) {
       .def_property_readonly(
          "cls_name", [](const NNModuleInfo& s) { return s.cls_name_.str(); });
 
+  py::class_<OptimizerInfo>(m, "_OptInfo")
+      .def_property_readonly("self", [](const OptimizerInfo& a) {
+        return reinterpret_cast<intptr_t>(a.self_.value_of());
+      });
+
   py::class_<ExtraFields<EventType::PyCall>, PyExtraFieldsBase>(m, "_ExtraFields_PyCall")
-      .def_readonly("module", &ExtraFields<EventType::PyCall>::module_)
+      .def_readonly("opt", &ExtraFields<EventType::PyCall>::opt_)
       .def_readonly("callsite", &ExtraFields<EventType::PyCall>::callsite_)
       .def_readonly("caller", &ExtraFields<EventType::PyCall>::caller_)
       .def_readonly("module", &ExtraFields<EventType::PyCall>::module_);
diff --git a/torch/optim/optimizer.py b/torch/optim/optimizer.py
index 7b32603babac6..9422b3b0f94d1 100644
--- a/torch/optim/optimizer.py
+++ b/torch/optim/optimizer.py
@@ -114,6 +114,19 @@ def _cuda_graph_capture_health_check(self):
                         "instance, capturable=True can impair performance, and you should set capturable=False.")
                 self._warned_capturable_if_run_uncaptured = True
 
+    def _optimizer_step_code(self):
+        """Entry point for `torch.profiler`.
+
+        When python tracing is enabled the profiler will hook into this
+        function at the CPython level to inspect the optimizer's parameters
+        and param groups. It is called after `step()` since many optimizers
+        lazily initialize state.
+
+        This is a workaround due to the lack of a proper step hook on the
+        optimizer, and will be removed once such a hook exists.
+        """
+        pass
+
     def _hook_for_profile(self):
         self._zero_grad_profile_name = "Optimizer.zero_grad#{}.zero_grad".format(self.__class__.__name__)
 
@@ -124,7 +137,10 @@ def wrapper(*args, **kwargs):
                 obj, *_ = args
                 profile_name = "Optimizer.step#{}.step".format(obj.__class__.__name__)
                 with torch.autograd.profiler.record_function(profile_name):
-                    return func(*args, **kwargs)
+                    out = func(*args, **kwargs)
+                    obj._optimizer_step_code()
+                    return out
+
             return wrapper
 
         hooked = getattr(self.__class__.step, "hooked", None)
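Usage note (not part of the patch): the sketch below shows how the optimizer information exposed by this change could be read back from the profiler's event tree, mirroring the new `test_optimizer` case above. `torch.profiler.profile`, `kineto_results.experimental_event_tree()`, and the `extra_fields.opt.self` field are taken from the diff; the toy model, the choice of SGD, and the `visit` helper are illustrative only.

import torch

model = torch.nn.Linear(10, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# with_stack=True enables the python tracer, which is what hooks
# Optimizer._optimizer_step_code at the CPython level.
with torch.profiler.profile(with_stack=True, profile_memory=True) as p:
    loss = model(torch.rand(10)).sum()
    loss.backward()
    opt.step()  # _optimizer_step_code() runs here, so the tracer records the optimizer

def visit(nodes):
    for node in nodes:
        # Nodes whose extra_fields carry the new `opt` field correspond to the
        # _optimizer_step_code frame; per the test above, `opt.self` is the
        # address of the profiled optimizer and should equal id(opt).
        opt_info = getattr(node.extra_fields, "opt", None)
        if opt_info:
            print(hex(opt_info.self))
        visit(node.children)

visit(p.profiler.kineto_results.experimental_event_tree())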