@@ -153,6 +153,8 @@ struct TORCH_API ExecutionTraceObserver { // NOLINT
153
153
state_ = newState;
154
154
}
155
155
156
+ bool record_integral_tensor_range{false };
157
+
156
158
private:
157
159
static bool callbackShouldBeEnabled (RunState run_state) {
158
160
return run_state == ExecutionTraceObserver::RunState::enabled;
@@ -189,6 +191,28 @@ struct FunctionCallContext : public ObserverContext { // NOLINT
189
191
std::vector<std::string> inputShapes;
190
192
std::vector<std::string> inputStrides;
191
193
std::vector<std::string> inputValues;
194
+ std::map<int , std::pair<long , long >> tensor_index_min_max_map;
195
+
196
+ std::string get_string_for_tensor_range () {
197
+ if (tensor_index_min_max_map.empty ()) {
198
+ return " " ;
199
+ }
200
+
201
+ std::string result = " {" ;
202
+ unsigned int i = 0 ;
203
+ for (auto const & [key, value] : tensor_index_min_max_map) {
204
+ if (i == tensor_index_min_max_map.size () - 1 ) {
205
+ result += json_str_escape (
206
+ fmt::format (" \" {}\" :[{},{}]" , key, value.first , value.second ));
207
+ } else {
208
+ result += json_str_escape (
209
+ fmt::format (" \" {}\" :[{},{}]," , key, value.first , value.second ));
210
+ }
211
+ i++;
212
+ }
213
+ result += " }" ;
214
+ return result;
215
+ }
192
216
};
193
217
194
218
// Opens the json file to write the execution trace.
@@ -240,14 +264,15 @@ static void writeJsonNode(
240
264
const std::string& operator_schema = " " ,
241
265
const std::string& kernelBackend = " " ,
242
266
const std::string& kernelFile = " " ,
267
+ const std::string& tensor_range = " " ,
243
268
const std::string& additiona_attrs = " " ) {
244
269
out << fmt::format (
245
270
R"JSON(
246
271
{{
247
272
"id": {}, "name": "{}", "ctrl_deps": {},
248
273
"inputs": {{"values": {}, "shapes": {}, "types": {}, "strides": {}}},
249
274
"outputs": {{"values": {}, "shapes": {}, "types": {}, "strides": {}}},
250
- "attrs": [{{"name": "rf_id", "type": "uint64", "value": {}}},{{"name": "fw_parent", "type": "uint64", "value": {}}},{{"name": "seq_id", "type": "int64", "value": {}}},{{"name": "scope", "type": "uint64", "value": {}}},{{"name": "tid", "type": "uint64", "value": {}}},{{"name": "fw_tid", "type": "uint64", "value": {}}},{{"name": "op_schema", "type": "string", "value": "{}"}},{{"name": "kernel_backend", "type": "string", "value": "{}"}},{{"name": "kernel_file", "type": "string", "value": "{}"}}{}]
275
+ "attrs": [{{"name": "rf_id", "type": "uint64", "value": {}}},{{"name": "fw_parent", "type": "uint64", "value": {}}},{{"name": "seq_id", "type": "int64", "value": {}}},{{"name": "scope", "type": "uint64", "value": {}}},{{"name": "tid", "type": "uint64", "value": {}}},{{"name": "fw_tid", "type": "uint64", "value": {}}},{{"name": "op_schema", "type": "string", "value": "{}"}},{{"name": "kernel_backend", "type": "string", "value": "{}"}},{{"name": "kernel_file", "type": "string", "value": "{}"}},{{"name": "tensor_range", "type": "string", "value": "{}"}} {}]
251
276
}})JSON" ,
252
277
id,
253
278
name,
@@ -269,6 +294,7 @@ static void writeJsonNode(
269
294
operator_schema,
270
295
kernelBackend,
271
296
kernelFile,
297
+ tensor_range,
272
298
additiona_attrs);
273
299
}
274
300
@@ -354,6 +380,9 @@ static ExecutionTraceObserver::ID getObjectID(
354
380
static std::tuple<std::string, std::string, std::string, std::string>
355
381
convertIValue (
356
382
ExecutionTraceObserver& ob,
383
+ int & tensorIndex,
384
+ std::map<int , std::pair<long , long >>& tensor_index_min_max_map,
385
+ bool isInput,
357
386
const c10::IValue& val,
358
387
const bool baseType = true ,
359
388
const size_t maxArrayLen = kMaxNumElements ) {
@@ -391,7 +420,18 @@ convertIValue(
391
420
numel = tensor_impl->numel ();
392
421
itemsize = tensor_impl->itemsize ();
393
422
device_str = tensor_impl->device ().str ();
423
+
424
+ if (ob.record_integral_tensor_range && isInput &&
425
+ at::isIntegralType (tensor.scalar_type (), false ) &&
426
+ tensor.numel () != 0 ) {
427
+ enableRecordFunction (false );
428
+ long min = tensor.min ().item ().toLong ();
429
+ long max = tensor.max ().item ().toLong ();
430
+ enableRecordFunction (true );
431
+ tensor_index_min_max_map[tensorIndex] = std::make_pair (min, max);
432
+ }
394
433
}
434
+ tensorIndex++;
395
435
tensor_value = fmt::format (
396
436
" [{},{},{},{},{},\" {}\" ]" ,
397
437
tensor_id,
@@ -410,7 +450,14 @@ convertIValue(
410
450
std::vector<std::string> type_array;
411
451
std::vector<std::string> value_array;
412
452
for (const auto j : c10::irange (tuple_size)) {
413
- auto tuple = convertIValue (ob, val_tuple[j], false , maxArrayLen);
453
+ auto tuple = convertIValue (
454
+ ob,
455
+ tensorIndex,
456
+ tensor_index_min_max_map,
457
+ isInput,
458
+ val_tuple[j],
459
+ false ,
460
+ maxArrayLen);
414
461
shape_array.push_back (std::get<0 >(tuple));
415
462
stride_array.push_back (std::get<1 >(tuple));
416
463
type_array.push_back (std::get<2 >(tuple));
@@ -431,7 +478,14 @@ convertIValue(
431
478
std::vector<std::string> type_array;
432
479
std::vector<std::string> value_array;
433
480
for (const auto j : c10::irange (list_size)) {
434
- auto tuple = convertIValue (ob, val_list.get (j), false , maxArrayLen);
481
+ auto tuple = convertIValue (
482
+ ob,
483
+ tensorIndex,
484
+ tensor_index_min_max_map,
485
+ isInput,
486
+ val_list.get (j),
487
+ false ,
488
+ maxArrayLen);
435
489
shape_array.push_back (std::get<0 >(tuple));
436
490
stride_array.push_back (std::get<1 >(tuple));
437
491
type_array.push_back (std::get<2 >(tuple));
@@ -462,13 +516,16 @@ convertIValue(
462
516
463
517
static void appendValueInfo (
464
518
ExecutionTraceObserver& ob,
519
+ int & tensorIndex,
520
+ std::map<int , std::pair<long , long >>& tensor_index_min_max_map,
521
+ bool isInput,
465
522
const c10::IValue& val,
466
523
std::vector<std::string>& shapes,
467
524
std::vector<std::string>& strides,
468
525
std::vector<std::string>& types,
469
526
std::vector<std::string>& values) {
470
- auto tuple = convertIValue (ob, val, true );
471
-
527
+ auto tuple = convertIValue (
528
+ ob, tensorIndex, tensor_index_min_max_map, isInput, val, true );
472
529
shapes.push_back (std::get<0 >(tuple));
473
530
strides.push_back (std::get<1 >(tuple));
474
531
types.push_back (std::get<2 >(tuple));
@@ -529,9 +586,10 @@ inline std::string getCommsNodeAttrs(const RecordFunction& fn) { // NOLINT
529
586
}
530
587
531
588
// get NcclMeta from record function, this used ParamCommsDebugInfo above
532
- // since we currently have this read called in onFunctionExit flow, we should
533
- // only introspect output tensors to prevent an INTERNAL ASSERT FAILED in
534
- // RecordFunction when we try to read input in RecordFunction exit methods.
589
+ // since we currently have this read called in onFunctionExit flow, we
590
+ // should only introspect output tensors to prevent an INTERNAL ASSERT
591
+ // FAILED in RecordFunction when we try to read input in RecordFunction exit
592
+ // methods.
535
593
auto meta = saveNcclMeta (fn, SaveNcclMetaConfig (false , true , false , true ));
536
594
537
595
auto addAttr =
@@ -577,7 +635,8 @@ static void recordOperatorStart(
577
635
{
578
636
const std::lock_guard<std::recursive_mutex> lock (ob.gMutex );
579
637
580
- // if current thread stack is empty, push the root node to the stack first
638
+ // if current thread stack is empty, push the root node to the stack
639
+ // first
581
640
if (ob.opStack [tid].empty ()) {
582
641
auto thread_node_id = ob.getNewID ();
583
642
ob.opStack [tid].push (thread_node_id);
@@ -605,10 +664,15 @@ static void recordOperatorStart(
605
664
const auto inputs = fn.inputs ();
606
665
// need to account for Stack mode where the inputs are at the end.
607
666
size_t input_start = inputs.size () - num_inputs;
608
-
667
+ // tensor_index is the index of the flattened tensor list for all input
668
+ // tensors
669
+ int tensor_index = 0 ;
609
670
for (const auto i : c10::irange (input_start, inputs.size ())) {
610
671
appendValueInfo (
611
672
ob,
673
+ tensor_index,
674
+ fc.tensor_index_min_max_map ,
675
+ true ,
612
676
inputs[i],
613
677
fc.inputShapes ,
614
678
fc.inputStrides ,
@@ -623,8 +687,8 @@ static void recordOperatorStart(
623
687
624
688
fc.parentId = ob.opStack [tid].top ();
625
689
// get parent id from the forward stack, this can be different for
626
- // autograd ops, which may execute on a different thread than the original
627
- // thread (which should have the parent op on the stack).
690
+ // autograd ops, which may execute on a different thread than the
691
+ // original thread (which should have the parent op on the stack).
628
692
auto fw_tid = fn.forwardThreadId ();
629
693
if (fw_tid != 0 ) {
630
694
fc.fwParentId = ob.opStack [fw_tid].top ();
@@ -706,9 +770,13 @@ static void onFunctionExit(const RecordFunction& fn, ObserverContext* ctx_ptr) {
706
770
std::vector<std::string> output_shapes;
707
771
std::vector<std::string> output_values;
708
772
try {
773
+ int tensor_index = 0 ;
709
774
for (const auto i : c10::irange (output_start, outputs.size ())) {
710
775
appendValueInfo (
711
776
*ob,
777
+ tensor_index,
778
+ fc.tensor_index_min_max_map ,
779
+ false ,
712
780
outputs.at (i),
713
781
output_shapes,
714
782
output_strides,
@@ -752,6 +820,7 @@ static void onFunctionExit(const RecordFunction& fn, ObserverContext* ctx_ptr) {
752
820
op_schema_str,
753
821
fc.kernelBackend ,
754
822
fc.kernelFile ,
823
+ fc.get_string_for_tensor_range (),
755
824
additiona_attrs);
756
825
ob->out << " ," ;
757
826
}
@@ -762,8 +831,8 @@ static void onFunctionExit(const RecordFunction& fn, ObserverContext* ctx_ptr) {
762
831
}
763
832
}
764
833
765
- // Add execution trace observer callback functions to the RecordFunction global
766
- // observers.
834
+ // Add execution trace observer callback functions to the RecordFunction
835
+ // global observers.
767
836
bool addExecutionTraceObserver (const std::string& output_file_path) {
768
837
// Check if the observer is already initialized.
769
838
if (ObserverManager::get () == nullptr ) {
@@ -776,6 +845,13 @@ bool addExecutionTraceObserver(const std::string& output_file_path) {
776
845
return false ;
777
846
}
778
847
848
+ // check if the environment variable is set to force recording integer
849
+ // tensors
850
+ auto env_variable =
851
+ getenv (" ENABLE_PYTORCH_EXECUTION_TRACE_INTEGRAL_TENSOR_RANGE" );
852
+ if (env_variable != nullptr ) {
853
+ ob.record_integral_tensor_range = true ;
854
+ }
779
855
ob.cbHandle = addGlobalCallback (
780
856
RecordFunctionCallback (&onFunctionEnter, &onFunctionExit)
781
857
.needsInputs (true )
0 commit comments