
Commit 2ea4b56

shengfukevin authored and pytorchmergebot committed
Record min/max of integral tensor in ET (pytorch#143088)
Summary: In et-replay, random data is used to run the operators. However, this does not work well for ops that use indices to access tensors, for example embedding ops, which use indices to look up the embedding table. If random data is used for these index inputs, et-replay usually runs into invalid memory accesses. To fix this, ET provides the environment variable "ENABLE_PYTORCH_EXECUTION_TRACE_INTEGRAL_TENSOR_RANGE"; when it is set, ET captures the min/max values of each flattened integral input tensor. et_replay then uses the recorded min/max to generate random tensors within that range, which fixes the invalid memory access issue.

Test Plan: buck2 run mode/opt caffe2/test:test_profiler_cuda -- profiler.test_execution_trace.TestExecutionTraceCUDA.test_execution_trace_record_integral_tensor_range_cuda

Differential Revision: D66666931

Pull Request resolved: pytorch#143088
Approved by: https://github.com/sanrise
1 parent bceedee commit 2ea4b56
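
For illustration, a minimal, hypothetical sketch (not part of this commit) of how a replay tool such as et_replay could use the recorded range: once the min/max of an integral input tensor is known, indices can be sampled from that closed interval instead of from arbitrary random data, so index-based ops such as gather or embedding lookups stay in bounds. The helper name below is made up for the example.

import torch

def make_random_integral_tensor(shape, value_range, device="cpu", dtype=torch.int64):
    # value_range is the (min, max) pair recorded in the execution trace for this
    # input tensor; torch.randint's upper bound is exclusive, so add 1 to include max.
    lo, hi = value_range
    return torch.randint(lo, hi + 1, shape, device=device, dtype=dtype)

# Example: the index input of the aten::gather call in the test below has the
# recorded range [0, 1], so replayed indices always fall inside dim 1 of a 2x2 source.
idx = make_random_integral_tensor((2, 2), (0, 1))
src = torch.randn(2, 2)
out = torch.gather(src, 1, idx)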

File tree

2 files changed: +123 −14 lines changed


test/profiler/test_execution_trace.py

Lines changed: 33 additions & 0 deletions
@@ -14,6 +14,7 @@
     pass
 
 import json
+import os
 import sys
 import tempfile
 import unittest
@@ -411,6 +412,38 @@ def fn(nt):
                 found_cos = True
         assert found_cos
 
+    @unittest.skipIf(
+        not TEST_CUDA,
+        "need CUDA device availability to run",
+    )
+    def test_execution_trace_record_integral_tensor_range(self):
+        fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False)
+        fp.close()
+
+        os.environ["ENABLE_PYTORCH_EXECUTION_TRACE_INTEGRAL_TENSOR_RANGE"] = "1"
+        t1 = torch.tensor([[1, 2], [3, 4]]).cuda()
+        t2 = torch.tensor([[0, 0], [1, 0]]).cuda()
+        with profile(
+            activities=supported_activities(),
+            schedule=torch.profiler.schedule(
+                skip_first=0, wait=0, warmup=0, active=1, repeat=1
+            ),
+            record_shapes=True,
+            execution_trace_observer=(
+                ExecutionTraceObserver().register_callback(fp.name)
+            ),
+        ) as p:
+            torch.gather(t1, 1, t2)
+            p.step()
+
+        nodes = self.get_execution_trace_root(fp.name)
+        for n in nodes:
+            assert "name" in n
+            if "aten::gather" in n["name"]:
+                for attr in n["attrs"]:
+                    if attr["name"] == "tensor_range":
+                        assert attr["value"] == '{"0":[1,4],"1":[0,1]}'
+
 
 devices = ["cpu", "cuda"]
 if TEST_XPU:
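
For reference, a hedged sketch (not part of this commit) of how the tensor_range attribute checked above could be read back from the execution trace JSON; the attribute value is itself a JSON string mapping the flattened input-tensor index to its [min, max] pair. It assumes the trace's top-level "nodes" list, the same structure that get_execution_trace_root walks in this test.

import json

def read_tensor_ranges(et_path):
    # Collect {node name: {flattened tensor index: (min, max)}} for every node
    # that carries a non-empty tensor_range attribute.
    with open(et_path) as f:
        trace = json.load(f)
    ranges = {}
    for node in trace["nodes"]:
        for attr in node.get("attrs", []):
            if attr["name"] == "tensor_range" and attr["value"]:
                parsed = json.loads(attr["value"])
                ranges[node["name"]] = {int(i): tuple(mm) for i, mm in parsed.items()}
    return ranges

# For the gather node above this would yield {"aten::gather": {0: (1, 4), 1: (0, 1)}}.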

torch/csrc/profiler/standalone/execution_trace_observer.cpp

Lines changed: 90 additions & 14 deletions
@@ -153,6 +153,8 @@ struct TORCH_API ExecutionTraceObserver { // NOLINT
     state_ = newState;
   }
 
+  bool record_integral_tensor_range{false};
+
  private:
   static bool callbackShouldBeEnabled(RunState run_state) {
     return run_state == ExecutionTraceObserver::RunState::enabled;
@@ -189,6 +191,28 @@ struct FunctionCallContext : public ObserverContext { // NOLINT
   std::vector<std::string> inputShapes;
   std::vector<std::string> inputStrides;
   std::vector<std::string> inputValues;
+  std::map<int, std::pair<long, long>> tensor_index_min_max_map;
+
+  std::string get_string_for_tensor_range() {
+    if (tensor_index_min_max_map.empty()) {
+      return "";
+    }
+
+    std::string result = "{";
+    unsigned int i = 0;
+    for (auto const& [key, value] : tensor_index_min_max_map) {
+      if (i == tensor_index_min_max_map.size() - 1) {
+        result += json_str_escape(
+            fmt::format("\"{}\":[{},{}]", key, value.first, value.second));
+      } else {
+        result += json_str_escape(
+            fmt::format("\"{}\":[{},{}],", key, value.first, value.second));
+      }
+      i++;
+    }
+    result += "}";
+    return result;
+  }
 };
 
 // Opens the json file to write the execution trace.
@@ -240,14 +264,15 @@ static void writeJsonNode(
     const std::string& operator_schema = "",
     const std::string& kernelBackend = "",
     const std::string& kernelFile = "",
+    const std::string& tensor_range = "",
     const std::string& additiona_attrs = "") {
   out << fmt::format(
       R"JSON(
     {{
      "id": {}, "name": "{}", "ctrl_deps": {},
      "inputs": {{"values": {}, "shapes": {}, "types": {}, "strides": {}}},
      "outputs": {{"values": {}, "shapes": {}, "types": {}, "strides": {}}},
-     "attrs": [{{"name": "rf_id", "type": "uint64", "value": {}}},{{"name": "fw_parent", "type": "uint64", "value": {}}},{{"name": "seq_id", "type": "int64", "value": {}}},{{"name": "scope", "type": "uint64", "value": {}}},{{"name": "tid", "type": "uint64", "value": {}}},{{"name": "fw_tid", "type": "uint64", "value": {}}},{{"name": "op_schema", "type": "string", "value": "{}"}},{{"name": "kernel_backend", "type": "string", "value": "{}"}},{{"name": "kernel_file", "type": "string", "value": "{}"}}{}]
+     "attrs": [{{"name": "rf_id", "type": "uint64", "value": {}}},{{"name": "fw_parent", "type": "uint64", "value": {}}},{{"name": "seq_id", "type": "int64", "value": {}}},{{"name": "scope", "type": "uint64", "value": {}}},{{"name": "tid", "type": "uint64", "value": {}}},{{"name": "fw_tid", "type": "uint64", "value": {}}},{{"name": "op_schema", "type": "string", "value": "{}"}},{{"name": "kernel_backend", "type": "string", "value": "{}"}},{{"name": "kernel_file", "type": "string", "value": "{}"}},{{"name": "tensor_range", "type": "string", "value": "{}"}}{}]
    }})JSON",
       id,
       name,
@@ -269,6 +294,7 @@ static void writeJsonNode(
       operator_schema,
       kernelBackend,
       kernelFile,
+      tensor_range,
       additiona_attrs);
 }
 
@@ -354,6 +380,9 @@ static ExecutionTraceObserver::ID getObjectID(
 static std::tuple<std::string, std::string, std::string, std::string>
 convertIValue(
     ExecutionTraceObserver& ob,
+    int& tensorIndex,
+    std::map<int, std::pair<long, long>>& tensor_index_min_max_map,
+    bool isInput,
     const c10::IValue& val,
     const bool baseType = true,
     const size_t maxArrayLen = kMaxNumElements) {
@@ -391,7 +420,18 @@ convertIValue(
       numel = tensor_impl->numel();
       itemsize = tensor_impl->itemsize();
       device_str = tensor_impl->device().str();
+
+      if (ob.record_integral_tensor_range && isInput &&
+          at::isIntegralType(tensor.scalar_type(), false) &&
+          tensor.numel() != 0) {
+        enableRecordFunction(false);
+        long min = tensor.min().item().toLong();
+        long max = tensor.max().item().toLong();
+        enableRecordFunction(true);
+        tensor_index_min_max_map[tensorIndex] = std::make_pair(min, max);
+      }
     }
+    tensorIndex++;
     tensor_value = fmt::format(
         "[{},{},{},{},{},\"{}\"]",
         tensor_id,
@@ -410,7 +450,14 @@ convertIValue(
     std::vector<std::string> type_array;
     std::vector<std::string> value_array;
     for (const auto j : c10::irange(tuple_size)) {
-      auto tuple = convertIValue(ob, val_tuple[j], false, maxArrayLen);
+      auto tuple = convertIValue(
+          ob,
+          tensorIndex,
+          tensor_index_min_max_map,
+          isInput,
+          val_tuple[j],
+          false,
+          maxArrayLen);
       shape_array.push_back(std::get<0>(tuple));
       stride_array.push_back(std::get<1>(tuple));
       type_array.push_back(std::get<2>(tuple));
@@ -431,7 +478,14 @@ convertIValue(
     std::vector<std::string> type_array;
     std::vector<std::string> value_array;
     for (const auto j : c10::irange(list_size)) {
-      auto tuple = convertIValue(ob, val_list.get(j), false, maxArrayLen);
+      auto tuple = convertIValue(
+          ob,
+          tensorIndex,
+          tensor_index_min_max_map,
+          isInput,
+          val_list.get(j),
+          false,
+          maxArrayLen);
       shape_array.push_back(std::get<0>(tuple));
       stride_array.push_back(std::get<1>(tuple));
       type_array.push_back(std::get<2>(tuple));
@@ -462,13 +516,16 @@ convertIValue(
 
 static void appendValueInfo(
     ExecutionTraceObserver& ob,
+    int& tensorIndex,
+    std::map<int, std::pair<long, long>>& tensor_index_min_max_map,
+    bool isInput,
     const c10::IValue& val,
     std::vector<std::string>& shapes,
     std::vector<std::string>& strides,
     std::vector<std::string>& types,
     std::vector<std::string>& values) {
-  auto tuple = convertIValue(ob, val, true);
-
+  auto tuple = convertIValue(
+      ob, tensorIndex, tensor_index_min_max_map, isInput, val, true);
   shapes.push_back(std::get<0>(tuple));
   strides.push_back(std::get<1>(tuple));
   types.push_back(std::get<2>(tuple));
@@ -529,9 +586,10 @@ inline std::string getCommsNodeAttrs(const RecordFunction& fn) { // NOLINT
   }
 
   // get NcclMeta from record function, this used ParamCommsDebugInfo above
-  // since we currently have this read called in onFunctionExit flow, we should
-  // only introspect output tensors to prevent an INTERNAL ASSERT FAILED in
-  // RecordFunction when we try to read input in RecordFunction exit methods.
+  // since we currently have this read called in onFunctionExit flow, we
+  // should only introspect output tensors to prevent an INTERNAL ASSERT
+  // FAILED in RecordFunction when we try to read input in RecordFunction exit
+  // methods.
   auto meta = saveNcclMeta(fn, SaveNcclMetaConfig(false, true, false, true));
 
   auto addAttr =
@@ -577,7 +635,8 @@ static void recordOperatorStart(
   {
     const std::lock_guard<std::recursive_mutex> lock(ob.gMutex);
 
-    // if current thread stack is empty, push the root node to the stack first
+    // if current thread stack is empty, push the root node to the stack
+    // first
     if (ob.opStack[tid].empty()) {
       auto thread_node_id = ob.getNewID();
       ob.opStack[tid].push(thread_node_id);
@@ -605,10 +664,15 @@ static void recordOperatorStart(
     const auto inputs = fn.inputs();
     // need to account for Stack mode where the inputs are at the end.
     size_t input_start = inputs.size() - num_inputs;
-
+    // tensor_index is the index of the flattened tensor list for all input
+    // tensors
+    int tensor_index = 0;
     for (const auto i : c10::irange(input_start, inputs.size())) {
       appendValueInfo(
           ob,
+          tensor_index,
+          fc.tensor_index_min_max_map,
+          true,
           inputs[i],
           fc.inputShapes,
           fc.inputStrides,
@@ -623,8 +687,8 @@ static void recordOperatorStart(
 
     fc.parentId = ob.opStack[tid].top();
     // get parent id from the forward stack, this can be different for
-    // autograd ops, which may execute on a different thread than the original
-    // thread (which should have the parent op on the stack).
+    // autograd ops, which may execute on a different thread than the
+    // original thread (which should have the parent op on the stack).
     auto fw_tid = fn.forwardThreadId();
     if (fw_tid != 0) {
       fc.fwParentId = ob.opStack[fw_tid].top();
@@ -706,9 +770,13 @@ static void onFunctionExit(const RecordFunction& fn, ObserverContext* ctx_ptr) {
     std::vector<std::string> output_shapes;
     std::vector<std::string> output_values;
     try {
+      int tensor_index = 0;
      for (const auto i : c10::irange(output_start, outputs.size())) {
        appendValueInfo(
            *ob,
+            tensor_index,
+            fc.tensor_index_min_max_map,
+            false,
            outputs.at(i),
            output_shapes,
            output_strides,
@@ -752,6 +820,7 @@ static void onFunctionExit(const RecordFunction& fn, ObserverContext* ctx_ptr) {
           op_schema_str,
           fc.kernelBackend,
           fc.kernelFile,
+          fc.get_string_for_tensor_range(),
           additiona_attrs);
       ob->out << ",";
     }
@@ -762,8 +831,8 @@ static void onFunctionExit(const RecordFunction& fn, ObserverContext* ctx_ptr) {
   }
 }
 
-// Add execution trace observer callback functions to the RecordFunction global
-// observers.
+// Add execution trace observer callback functions to the RecordFunction
+// global observers.
 bool addExecutionTraceObserver(const std::string& output_file_path) {
   // Check if the observer is already initialized.
   if (ObserverManager::get() == nullptr) {
@@ -776,6 +845,13 @@ bool addExecutionTraceObserver(const std::string& output_file_path) {
       return false;
     }
 
+    // check if the environment variable is set to force recording integer
+    // tensors
+    auto env_variable =
+        getenv("ENABLE_PYTORCH_EXECUTION_TRACE_INTEGRAL_TENSOR_RANGE");
+    if (env_variable != nullptr) {
+      ob.record_integral_tensor_range = true;
+    }
     ob.cbHandle = addGlobalCallback(
         RecordFunctionCallback(&onFunctionEnter, &onFunctionExit)
             .needsInputs(true)
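
Putting the pieces together, a hedged CPU usage sketch mirroring the CUDA test above (not part of this commit): the environment variable has to be set before register_callback() is called, because addExecutionTraceObserver reads it when the callback is installed.

import os
import torch
from torch.profiler import ExecutionTraceObserver, profile, supported_activities

# Set before register_callback(), which is when the observer checks the flag.
os.environ["ENABLE_PYTORCH_EXECUTION_TRACE_INTEGRAL_TENSOR_RANGE"] = "1"

et = ExecutionTraceObserver().register_callback("trace.et.json")
with profile(
    activities=supported_activities(),
    schedule=torch.profiler.schedule(skip_first=0, wait=0, warmup=0, active=1, repeat=1),
    record_shapes=True,
    execution_trace_observer=et,
) as p:
    table = torch.randn(10, 4)
    idx = torch.tensor([0, 3, 9])
    torch.index_select(table, 0, idx)  # integral input, so its min/max is recorded
    p.step()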
