Skip to content

Commit 78fcde9

Browse files
smessmer authored and facebook-github-bot committed
Trace scattered tensor options arguments (pytorch#44071)
Summary: Pull Request resolved: pytorch#44071 Previously, tracing re-gathered ScalarType, Layout, Device, bool into a TensorOptions object and called `tracer::addInputs()` on the gathered TensorOptions argument. `tracer::addInputs()` then scattered them again and added the individual scattered arguments to the traced graph. This PR avoids the extraneous gathering and re-scattering step and calls `tracer::addInputs()` on the individual arguments directly. This avoids the perf hit of an unnecessary gathering step. This applies to both c10-full and non-c10-full ops. In the case of c10-full ops, the tracing kernel takes scattered arguments and we can directly pass them to `tracer::addInputs()`. In the case of non-c10-full ops, the kernel takes a `TensorOptions` argument but we still call `tracer::addInputs()` on the scattered arguments. ghstack-source-id: 112825793 Test Plan: waitforsandcastle vs master: https://www.internalfb.com/intern/fblearner/details/216129483/ vs previous diff: https://www.internalfb.com/intern/fblearner/details/216170069/ Reviewed By: ezyang Differential Revision: D23486638 fbshipit-source-id: e0b53e6673cef8d7f94158e718301eee261e5d22
1 parent 2ac7de7 commit 78fcde9

File tree

4 files changed

+35
-30
lines changed

4 files changed

+35
-30
lines changed

c10/core/TensorOptions.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ inline Device device_or_default(c10::optional<Device> device) {
3636
return device.has_value() ? *device : Device(kCPU);
3737
}
3838

39+
inline bool pinned_memory_or_default(c10::optional<bool> pinned_memory) {
40+
return pinned_memory.has_value() ? *pinned_memory : false;
41+
}
42+
3943
/// A class to encapsulate construction axes of a Tensor. TensorOptions was
4044
/// designed to support the Python style API for specifying construction options
4145
/// on factory functions, e.g.,
@@ -317,7 +321,7 @@ struct C10_API TensorOptions {
317321

318322
/// Returns the `pinned_memory` property of the `TensorOptions`.
319323
bool pinned_memory() const noexcept {
320-
return has_pinned_memory_ ? pinned_memory_ : false;
324+
return pinned_memory_or_default(pinned_memory_opt());
321325
}
322326

323327
/// Returns whether the `pinned_memory` is specified.

tools/autograd/gen_variable_type.py

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,18 @@
7171
# arguments (inside of the `native_functions.yaml`)
7272
RENAME_TRACE_ADD_ARGS = {
7373
'fill': '''\
74-
jit::tracer::addInputs(node, "options", TensorOptions());
74+
jit::tracer::addInputs(node, "options", c10::optional<ScalarType>());
75+
jit::tracer::addInputs(node, "options", layout_or_default(c10::nullopt));
76+
jit::tracer::addInputs(node, "options", device_or_default(c10::nullopt));
77+
jit::tracer::addInputs(node, "options", pinned_memory_or_default(c10::nullopt));
7578
c10::optional<MemoryFormat> memory_format = c10::MemoryFormat::Preserve;
7679
jit::tracer::addInputs(node, "memory_format", memory_format);
7780
''',
7881
'zero': '''\
79-
jit::tracer::addInputs(node, "options", TensorOptions());
82+
jit::tracer::addInputs(node, "options", c10::optional<ScalarType>());
83+
jit::tracer::addInputs(node, "options", layout_or_default(c10::nullopt));
84+
jit::tracer::addInputs(node, "options", device_or_default(c10::nullopt));
85+
jit::tracer::addInputs(node, "options", pinned_memory_or_default(c10::nullopt));
8086
c10::optional<MemoryFormat> memory_format = c10::MemoryFormat::Preserve;
8187
jit::tracer::addInputs(node, "memory_format", memory_format);
8288
''',
@@ -498,26 +504,35 @@ def format_trace_inputs(declaration):
498504

499505
def dispatch_trace_input(arg_spec):
500506
name, value, simple_type, nullable = arg_spec
501-
if declaration['use_c10_dispatcher'] == 'full':
502-
if value == "options":
503-
value = gather_tensor_options
504-
else:
505-
assert declaration['use_c10_dispatcher'] == 'with_codegenerated_unboxing_wrapper'
506507
# XXX: For args that have type Tensor?[], tracer will pass allow_undefined to addInputs
507508
if simple_type == 'TensorList' and nullable:
508509
return '''jit::tracer::addInputs(node, "{}", {}, {});'''.format(name, value, "true")
509510
else:
510-
return ADD_TRACE_INPUT.substitute(name=name, input=value)
511+
if value == "options":
512+
result = ""
513+
result += ADD_TRACE_INPUT.substitute(name=name, input="optTypeMetaToScalarType(options.dtype_opt())") + "\n"
514+
result += ADD_TRACE_INPUT.substitute(name=name, input="options.layout()") + "\n"
515+
result += ADD_TRACE_INPUT.substitute(name=name, input="options.device()") + "\n"
516+
result += ADD_TRACE_INPUT.substitute(name=name, input="options.pinned_memory()")
517+
return result
518+
else:
519+
return ADD_TRACE_INPUT.substitute(name=name, input=value)
511520

512-
trace_inputs = declaration['arguments']
521+
if declaration['use_c10_dispatcher'] == 'full':
522+
trace_inputs = declaration['schema_order_arguments']
523+
else:
524+
trace_inputs = declaration['arguments']
513525

514526
if is_out_overload(declaration):
515527
# *_out functions take the result as a first argument, but they are the
516528
# last argument in the JIT schema.
517529
out_input = trace_inputs[0]
518530
trace_inputs = trace_inputs[1:]
519531

520-
trace_input_spec = [(i['name'], i['name'], i['simple_type'], i.get('is_nullable')) for i in trace_inputs]
532+
if declaration['use_c10_dispatcher'] == 'full':
533+
trace_input_spec = [(i['name'], i['name'], i['type'], i.get('is_nullable')) for i in trace_inputs]
534+
else:
535+
trace_input_spec = [(i['name'], i['name'], i['simple_type'], i.get('is_nullable')) for i in trace_inputs]
521536

522537
trace_inputs = \
523538
'\n'.join(dispatch_trace_input(arg_spec) for arg_spec in trace_input_spec)
@@ -526,11 +541,6 @@ def dispatch_trace_input(arg_spec):
526541
# for *_out functions, handle the result argument differently for inplace/outplace.
527542
# For inplace: just add the input to the end to conform to the JIT schema
528543
value = out_input['name']
529-
if declaration['use_c10_dispatcher'] == 'full':
530-
if value == "options":
531-
value = gather_tensor_options
532-
else:
533-
assert declaration['use_c10_dispatcher'] == 'with_codegenerated_unboxing_wrapper'
534544
inplace = ADD_TRACE_INPUT.substitute(name=out_input['name'], input=value)
535545

536546
# for outplace: do nothing, except if the declaration is a factory.
@@ -539,7 +549,11 @@ def dispatch_trace_input(arg_spec):
539549
trace_name = uninplace_api_name(declaration['api_name'])
540550
has_factory_name = trace_name in FACTORY_FUNCTION_NAMES
541551
if has_factory_name:
542-
outplace = ADD_TRACE_INPUT.substitute(name='out', input='out.options()')
552+
outplace = ""
553+
outplace += ADD_TRACE_INPUT.substitute(name='out', input='optTypeMetaToScalarType(out.options().dtype_opt())') + "\n"
554+
outplace += ADD_TRACE_INPUT.substitute(name='out', input='out.options().layout()') + "\n"
555+
outplace += ADD_TRACE_INPUT.substitute(name='out', input='out.options().device()') + "\n"
556+
outplace += ADD_TRACE_INPUT.substitute(name='out', input='out.options().pinned_memory()')
543557
else:
544558
outplace = ''
545559

torch/csrc/jit/frontend/tracer.cpp

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -696,15 +696,6 @@ void addInputs(
696696
}
697697
}
698698

699-
void addInputs(Node* n, const char* name, const at::TensorOptions& options) {
700-
// [TensorOptions in script] - update this when you change how we schematize
701-
// TensorOptions
702-
addInputs(n, name, options.dtype_opt());
703-
addInputs(n, name, options.layout());
704-
addInputs(n, name, options.device());
705-
addInputs(n, name, options.pinned_memory());
706-
}
707-
708699
void addInputs(Node* n, const char* name, at::IntArrayRef value) {
709700
using ArgumentStash = jit::tracer::ArgumentStash;
710701
std::vector<Value*> info = ArgumentStash::hasIntArrayRef(name)

torch/csrc/jit/frontend/tracer.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -266,10 +266,6 @@ TORCH_API void addInputs(
266266
Node* n,
267267
const char* name,
268268
const c10::optional<std::string>& value);
269-
TORCH_API void addInputs(
270-
Node* n,
271-
const char* name,
272-
const at::TensorOptions& value);
273269
TORCH_API void addInputs(Node* n, const char* name, at::Device value);
274270
TORCH_API void addInputs(Node* n, const char* name, at::Layout value);
275271
TORCH_API void addInputs(Node* n, const char* name, at::ScalarType value);

0 commit comments

Comments
 (0)