Skip to content

Commit

Permalink
Bug fix for modeling utilities function: apply_chunking_to_forward, c…
Browse files Browse the repository at this point in the history
…hunking should be in the chunking dimension, an exception was raised if the complete shape of the inputs was not the same rather than only the chunking dimension (huggingface#8391)

Co-authored-by: pedro <pe25171@mit.edu>
  • Loading branch information
2 people authored and fabiocapsouza committed Nov 15, 2020
1 parent 4a12398 commit 50a1ead
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1668,9 +1668,9 @@ def forward(self, hidden_states):
"""

assert len(input_tensors) > 0, "{} has to be a tuple/list of tensors".format(input_tensors)
tensor_shape = input_tensors[0].shape
tensor_shape = input_tensors[0].shape[chunk_dim]
assert all(
input_tensor.shape == tensor_shape for input_tensor in input_tensors
input_tensor.shape[chunk_dim] == tensor_shape for input_tensor in input_tensors
), "All input tenors have to be of the same shape"

# inspect.signature exist since python 3.5 and is a python method -> no problem with backward compatibility
Expand Down

0 comments on commit 50a1ead

Please sign in to comment.