Skip to content

Commit

Permalink
Bug fix for modeling utilities function: apply_chunking_to_forward, c…
Browse files Browse the repository at this point in the history
…hunking should be in the chunking dimension, an exception was raised if the complete shape of the inputs was not the same rather than only the chunking dimension (#8391)

Co-authored-by: pedro <pe25171@mit.edu>
  • Loading branch information
pedrocolon93 and pedro authored Nov 10, 2020
1 parent 70708cc commit eb3bd73
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1668,9 +1668,9 @@ def forward(self, hidden_states):
"""

assert len(input_tensors) > 0, "{} has to be a tuple/list of tensors".format(input_tensors)
tensor_shape = input_tensors[0].shape
tensor_shape = input_tensors[0].shape[chunk_dim]
assert all(
input_tensor.shape == tensor_shape for input_tensor in input_tensors
input_tensor.shape[chunk_dim] == tensor_shape for input_tensor in input_tensors
), "All input tenors have to be of the same shape"

# inspect.signature exist since python 3.5 and is a python method -> no problem with backward compatibility
Expand Down

0 comments on commit eb3bd73

Please sign in to comment.