[Profiler] tracking Optimizer (part 2 of Record Optimizer) (pytorch#84920)

Summary:
Part 2 of Record Optimizer param_groups and states (pytorch#84063):
- hook in from the optimizer step
- add the PyOptimizerCall call type
- declare the data types used for collection
- Python bindings
- a simple unit test case (a minimal usage sketch follows below)
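
A minimal usage sketch of how the new `opt` extra field surfaces from Python, modeled on the unit test added in test/profiler/test_profiler.py; `getattr` keeps the snippet independent of the exact binding imports:

```python
import torch
import torch.nn as nn

def find_optimizers(nodes, out=None):
    # Walk the profiler event tree and collect the address stored in the
    # new `opt.self` field of every Python-call event that carries one.
    if out is None:
        out = []
    for node in nodes:
        opt_info = getattr(node.extra_fields, "opt", None)
        if opt_info:
            out.append(opt_info.self)
        find_optimizers(node.children, out)
    return out

with torch.profiler.profile(with_stack=True, profile_memory=True) as p:
    net = nn.Linear(10, 2)
    opt = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)
    opt.zero_grad()
    net(torch.rand(10)).sum().backward()
    opt.step()

tree = p.profiler.kineto_results.experimental_event_tree()
assert find_optimizers(tree) == [id(opt)]  # one step() profiled, at opt's address
```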

Test Plan: buck run mode/opt //caffe2/test:profiler

Differential Revision: D39402667

Pull Request resolved: pytorch#84920
Approved by: https://github.com/robieta
slgong-fb authored and pytorchmergebot committed Sep 28, 2022
1 parent 1c0f0b3 commit f80ef73
Showing 5 changed files with 169 additions and 27 deletions.
49 changes: 39 additions & 10 deletions test/profiler/test_profiler.py
@@ -1259,7 +1259,6 @@ def test_nested_tensor_with_shapes(self):
self.assertTrue(len(e.input_shapes[0]) > 0)



def find_node_with_name(nodes, name):
for node in nodes:
if node.name() == name:
@@ -1268,6 +1267,17 @@ def find_node_with_name(nodes, name):
if result is not None:
return result


class SimpleNet(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(10, 5)
self.fc2 = nn.Linear(5, 2)

def forward(self, x):
return self.fc2(self.fc1(x))


class TestTorchTidyProfiler(TestCase):

def test_pointers_and_ids(self):
@@ -1336,6 +1346,7 @@ def get_fields(op_name, index):
self.assertEqual(c_id, c_id_new)
self.assertEqual(d_id, c_id_new)


def test_extra_fields(self):
with profile(with_stack=True, profile_memory=True) as p:
_ = torch.ones((1,))
@@ -1477,18 +1488,10 @@ def flat_out_extrafields(nodes, out=None):
flat_out_extrafields(node.children, out)
return out

class simpleNet(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(10, 5)
self.fc2 = nn.Linear(5, 2)

def forward(self, x):
return self.fc2(self.fc1(x))

inputs = torch.rand(10)
with torch.profiler.profile(with_stack=True, profile_memory=True) as p:
net = simpleNet()
net = SimpleNet()
out = net(inputs)

modules = flat_out_extrafields(p.profiler.kineto_results.experimental_event_tree())
@@ -1499,6 +1502,32 @@ def forward(self, x):
expected += [(name, val.storage().data_ptr()) for name, val in net.fc2._parameters.items()]
self.assertEqual(expected, params, f"{expected} vs. {params}")

def test_optimizer(self):

def flat_out_extrafields(nodes, out=None):
if out is None:
out = []
for node in nodes:
if isinstance(node.extra_fields, _ExtraFields_PyCall) and node.extra_fields.opt:
out.append(node.extra_fields.opt.self)
flat_out_extrafields(node.children, out)
return out

inputs = torch.rand(10)
with torch.profiler.profile(with_stack=True, profile_memory=True) as p:
net = SimpleNet()
opt = torch.optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

opt.zero_grad()
out = net(inputs)
loss = torch.nn.functional.cross_entropy(out, torch.rand(2))
loss.backward()
opt.step()

opts = flat_out_extrafields(p.profiler.kineto_results.experimental_event_tree())
self.assertEqual(len(opts), 1, f"Expected 1 optimizer, got {len(opts)}")
self.assertEqual(id(opt), opts[0], f"Optimizer addr ({id(opt)}) vs. Profiled optimizer addr ({opts[0]})")

def test_allocations(self):
gc.collect()
with profile(profile_memory=True) as p:
99 changes: 87 additions & 12 deletions torch/csrc/autograd/profiler_python.cpp
@@ -37,8 +37,8 @@ namespace torch {
namespace profiler {
namespace impl {
namespace {
enum CallType { PyCall = 0, PyModuleCall, PyCCall };
static constexpr size_t CallTypeSize = 3;
enum CallType { PyCall = 0, PyModuleCall, PyCCall, PyOptimizerCall };
static constexpr size_t CallTypeSize = 4;
using no_ephemeral_t = std::tuple<>;

// ============================================================================
@@ -63,7 +63,11 @@ struct CodeLocation {
int line_number_{0};
};

PyCodeObject* nnModuleCode() {
template <CallType C>
PyCodeObject* getCode();

template <>
PyCodeObject* getCode<CallType::PyModuleCall>() {
static auto module_call_code = []() {
pybind11::gil_scoped_acquire gil;
auto res = py::module::import("torch.nn")
@@ -75,14 +79,21 @@ PyCodeObject* nnModuleCode() {
return (PyCodeObject*)res;
}();
return module_call_code;
}

template <CallType C>
PyCodeObject* getCode();
};

template <>
PyCodeObject* getCode<CallType::PyModuleCall>() {
return nnModuleCode();
PyCodeObject* getCode<CallType::PyOptimizerCall>() {
static auto optimizer_step_code = []() {
pybind11::gil_scoped_acquire gil;
auto res = py::module::import("torch.optim")
.attr("Optimizer")
.attr("_optimizer_step_code")
.attr("__code__")
.ptr();
TORCH_INTERNAL_ASSERT(PyCode_Check(res));
return (PyCodeObject*)res;
}();
return optimizer_step_code;
};

} // namespace
@@ -215,6 +226,25 @@ struct Config<CallType::PyCCall> {
static constexpr EventType event_type = EventType::PyCCall;
};

template <>
struct Config<CallType::PyOptimizerCall> {
using key_t = PyOptimizerSelf;
using cls_t = PyOptimizerCls;
using ephemeral_t = PyFrameObject*;
struct info_t {
cls_t cls_;
std::vector<void*> params_;
std::vector<std::pair<std::string, void*>> states_;
};
struct cache_t {
c10::optional<CodeLocation>
location_; // optim.Optimizer._optimizer_step_code;
ska::flat_hash_map<key_t, info_t> optimizer_data_;
ska::flat_hash_map<cls_t, at::StringView> cls_names_;
};
static constexpr EventType event_type = EventType::PyCall;
};

// ============================================================================
// == Callsite & ValueCache: Storage during profiling =========================
// ============================================================================
@@ -245,6 +275,7 @@ class Callsite {
using PyCallKey = Config<CallType::PyCall>::key_t;
using PyModuleCallKey = Config<CallType::PyModuleCall>::key_t;
using PyCCallKey = Config<CallType::PyCCall>::key_t;
using PyOptimizerCallKey = Config<CallType::PyOptimizerCall>::key_t;

class ValueCache {
public:
@@ -254,11 +285,11 @@
template <CallType C>
auto load(const Callsite<C>& callsite, size_t python_tid) const {
auto caller = load<CallType::PyCall>(callsite.caller_);
TORCH_INTERNAL_ASSERT(!caller.second.has_value());
TORCH_INTERNAL_ASSERT(!caller.module_info_.has_value());
return ExtraFields<Config<C>::event_type>{
/*end_time_ns=*/std::numeric_limits<time_t>::min(),
python_tid,
caller.first,
caller.frame_state_,
load<C>(callsite.value_)};
}

@@ -354,6 +385,39 @@ ExtraFields<EventType::PyCall>::args_t ValueCache::load<CallType::PyModuleCall>(
cache.cls_names_.at(cls),
cache.modules_and_params_.at(key).second}};
}
template <>
void ValueCache::store<CallType::PyOptimizerCall>(
const PyOptimizerCallKey& key,
Config<CallType::PyOptimizerCall>::ephemeral_t frame) {
auto& cache = std::get<CallType::PyOptimizerCall>(state_);
if (C10_UNLIKELY(
cache.optimizer_data_.find(key) == cache.optimizer_data_.end())) {
auto cls = set_class<CallType::PyOptimizerCall>(this, cache, key, frame);
py::list param_groups_handle =
py::handle((PyObject*)key).attr("param_groups");
std::vector<void*> params_;
std::vector<std::pair<std::string, void*>> states_;

cache.optimizer_data_[key] = {cls, params_, states_};
}
}

template <>
ExtraFields<EventType::PyCall>::args_t ValueCache::load<
CallType::PyOptimizerCall>(const PyOptimizerCallKey& key) const {
auto& cache = std::get<CallType::PyOptimizerCall>(state_);
auto cls = cache.optimizer_data_.at(key).cls_;
auto frame_state = std::get<CallType::PyCall>(state_).at(*cache.location_);
return {
frame_state,
c10::nullopt,
OptimizerInfo{
key,
cls,
cache.cls_names_.at(cls),
cache.optimizer_data_.at(key).params_,
cache.optimizer_data_.at(key).states_}};
}

template <>
void ValueCache::store<CallType::PyCCall>(
@@ -565,13 +629,16 @@ class PythonTracer final : public python_tracer::PythonTracerBase {

torch::profiler::impl::RecordQueue* queue_;
PyCodeObject* module_call_code_;
PyCodeObject* optimizer_hook_;

std::deque<ThreadLocalResults> thread_local_results_;
ValueCache value_cache_;
};

PythonTracer::PythonTracer(torch::profiler::impl::RecordQueue* queue)
: queue_(queue), module_call_code_(nnModuleCode()) {
: queue_(queue),
module_call_code_(getCode<CallType::PyModuleCall>()),
optimizer_hook_(getCode<CallType::PyOptimizerCall>()) {
TORCH_CHECK(queue_ != nullptr);

bool expected{false};
@@ -682,6 +749,14 @@ void PythonTracer::recordPyCall(ThreadLocalResults& tls, PyFrameObject* frame) {
TORCH_INTERNAL_ASSERT(back != nullptr);
return tls.intern<CallType::PyModuleCall, E>(
frame, self.get(), back.get());
} else if (code.get() == optimizer_hook_) {
auto locals = THPObjectPtr(PyFrame_GetLocals(frame));
auto self = THPObjectPtr(PyDict_GetItemString(locals, "self"));
Py_INCREF(self.get());
auto back = THPFrameObjectPtr(PyFrame_GetBack(frame));
TORCH_INTERNAL_ASSERT(back != nullptr);
return tls.intern<CallType::PyOptimizerCall, E>(
frame, self.get(), back.get());
} else {
auto back = THPFrameObjectPtr(PyFrame_GetBack(frame));
auto f_back = (back.get() != nullptr) ? back.get() : frame;
23 changes: 20 additions & 3 deletions torch/csrc/profiler/collection.h
@@ -219,6 +219,8 @@ using strong_t = strong::
using PyModuleSelf = strong_t<PyObject*, struct PyModuleSelf_>;
using PyModuleCls = strong_t<PyObject*, struct PyModuleCls_>;
using PyMethod = strong_t</*PyMethodDef*/ void*, struct PyMethod_>;
using PyOptimizerSelf = strong_t<PyObject*, struct PyOptSelf_>;
using PyOptimizerCls = strong_t<PyObject*, struct PyOptimizer_>;

struct NNModuleInfo {
PyModuleSelf self_;
@@ -230,6 +232,15 @@ struct NNModuleInfo {
size_t id_{std::numeric_limits<size_t>::max()};
};

struct OptimizerInfo {
PyOptimizerSelf self_;
PyOptimizerCls opt_;
at::StringView opt_name_;

std::vector<void*> params_addr_;
std::vector<std::pair<std::string, void*>> opt_state_;
};

struct PyExtraFieldsBase {
PyExtraFieldsBase(time_t end_time_ns, size_t python_tid, PyFrameState caller)
: end_time_ns_{end_time_ns}, python_tid_{python_tid}, caller_{caller} {}
@@ -244,19 +255,25 @@ struct PyExtraFieldsBase {

template <>
struct ExtraFields<EventType::PyCall> : public PyExtraFieldsBase {
using args_t = std::pair<PyFrameState, c10::optional<NNModuleInfo>>;
using args_t = struct {
PyFrameState frame_state_;
c10::optional<NNModuleInfo> module_info_;
c10::optional<OptimizerInfo> opt_info_;
};

ExtraFields(
time_t end_time_ns,
size_t python_tid,
PyFrameState caller,
args_t args)
: PyExtraFieldsBase(end_time_ns, python_tid, caller),
callsite_{args.first},
module_{args.second} {}
callsite_{args.frame_state_},
module_{args.module_info_},
opt_{args.opt_info_} {}

PyFrameState callsite_;
c10::optional<NNModuleInfo> module_;
c10::optional<OptimizerInfo> opt_;
};

template <>
7 changes: 6 additions & 1 deletion torch/csrc/profiler/python/init.cpp
@@ -174,8 +174,13 @@ void initPythonBindings(PyObject* module) {
.def_property_readonly(
"cls_name", [](const NNModuleInfo& s) { return s.cls_name_.str(); });

py::class_<OptimizerInfo>(m, "_OptInfo")
.def_property_readonly("self", [](const OptimizerInfo& a) {
return reinterpret_cast<intptr_t>(a.self_.value_of());
});

py::class_<ExtraFields<EventType::PyCall>>(m, "_ExtraFields_PyCall")
.def_readonly("module", &ExtraFields<EventType::PyCall>::module_)
.def_readonly("opt", &ExtraFields<EventType::PyCall>::opt_)
.def_readonly("callsite", &ExtraFields<EventType::PyCall>::callsite_)
.def_readonly("caller", &ExtraFields<EventType::PyCall>::caller_)
.def_readonly("module", &ExtraFields<EventType::PyCall>::module_);
18 changes: 17 additions & 1 deletion torch/optim/optimizer.py
@@ -114,6 +114,19 @@ def _cuda_graph_capture_health_check(self):
"instance, capturable=True can impair performance, and you should set capturable=False.")
self._warned_capturable_if_run_uncaptured = True

def _optimizer_step_code(self):
"""Entry point for `torch.profile.profiler`.
When python tracing is enabled the profiler will hook into this
function at the CPython level to inspect the optimizer's parameters and
param groups. It is called after `step()` since many optimizers
lazily initialize state.
This is a workaround due to the lack of a proper step hook on the optimizer,
and will be removed once such a hook exists.
"""
pass

def _hook_for_profile(self):
self._zero_grad_profile_name = "Optimizer.zero_grad#{}.zero_grad".format(self.__class__.__name__)

@@ -124,7 +137,10 @@ def wrapper(*args, **kwargs):
obj, *_ = args
profile_name = "Optimizer.step#{}.step".format(obj.__class__.__name__)
with torch.autograd.profiler.record_function(profile_name):
return func(*args, **kwargs)
out = func(*args, **kwargs)
obj._optimizer_step_code()
return out

return wrapper

hooked = getattr(self.__class__.step, "hooked", None)
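
For intuition, a pure-Python analogy of the hook described in the `_optimizer_step_code` docstring above (a sketch only; the real matching happens at the CPython level in PythonTracer::recordPyCall in profiler_python.cpp): the tracer compares each called frame's code object against `Optimizer._optimizer_step_code.__code__` and, on a match, reads `self` out of the frame locals.

```python
import sys
import torch

# Code object the tracer keys on: the wrapped step() ends by calling
# _optimizer_step_code, so matching its frame amounts to observing a step.
OPT_STEP_CODE = torch.optim.Optimizer._optimizer_step_code.__code__

seen = []

def tracer(frame, event, arg):
    if event == "call" and frame.f_code is OPT_STEP_CODE:
        seen.append(id(frame.f_locals["self"]))  # the optimizer instance
    return None  # no per-line tracing needed

net = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(net.parameters(), lr=0.1)
net(torch.rand(4)).sum().backward()

sys.settrace(tracer)
opt.step()
sys.settrace(None)

assert seen == [id(opt)]
```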
