Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't try to bind unused inputs in the Training frontend #6166

Merged
merged 11 commits into from
Dec 18, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 12 additions & 7 deletions orttraining/orttraining/python/training/orttrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,15 +804,20 @@ def _training_session_run_helper(self, is_train, inputs, inputs_desc, outputs_de
else:
iobinding = self._eval_io_binding

# Get the list of the actual session inputs because unused inputs can be removed.
input_nodes = self._training_session.get_inputs()
input_node_names = [input_node.name for input_node in input_nodes]

# Bind input tensors
for input, input_desc in zip(inputs, inputs_desc):
device_index = _utils.get_device_index_from_input(input)
iobinding.bind_input(input_desc.name,
input.device.type,
device_index,
_utils.dtype_torch_to_numpy(input.dtype),
list(input.size()),
input.data_ptr())
if input_desc.name in input_node_names:
device_index = _utils.get_device_index_from_input(input)
iobinding.bind_input(input_desc.name,
input.device.type,
device_index,
_utils.dtype_torch_to_numpy(input.dtype),
list(input.size()),
input.data_ptr())

# Bind output tensors
outputs_desc_resolved = self._resolve_symbolic_dimensions(inputs, inputs_desc, outputs_desc)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1428,3 +1428,21 @@ def testORTTrainerOptionsDisabledAdasumFlag(test_input):

actual_values = orttrainer_options.ORTTrainerOptions(test_input)
assert actual_values.distributed.enable_adasum == False

def testORTTrainerUnusedInput():
class UnusedInputModel(torch.nn.Module):
def __init__(self):
super(UnusedInputModel, self).__init__()
def forward(self, x, y):
return torch.mean(x)

model = UnusedInputModel()
model_desc = {'inputs': [('x', [1]), ('y', [1])], 'outputs': [('loss', [], True)]}
optim_config = optim.LambConfig(lr=0.001)
trainer = orttrainer.ORTTrainer(model, model_desc, optim_config)

# Run just one step to make sure there are no iobinding errors for the unused input.
try:
trainer.train_step(torch.FloatTensor([1.0]), torch.FloatTensor([1.0]))
except RuntimeError:
pytest.fail("RuntimeError doing train_step with an unused input.")