fix: use bool instead of uint8/byte in Deberta/DebertaV2/SEW-D to make it compatible with TensorRT (huggingface#23683)

* Use bool instead of uint8/byte in DebertaV2 to make it compatible with TensorRT

TensorRT cannot accept an ONNX graph that contains uint8/byte intermediate tensors. This PR uses bool tensors instead of uint8/byte tensors so that the exported ONNX file works with TensorRT (see the sketch after the changed-files summary below).

* fix: use bool instead of uint8/byte in Deberta and SEW-D

---------

Co-authored-by: Yuxian Qiu <yuxianq@nvidia.com>
2 people authored and sheonhan committed Jun 1, 2023
1 parent 76da442 commit 135a35f
Showing 3 changed files with 8 additions and 11 deletions.
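A minimal sketch (not part of the commit) of what the dtype change amounts to in plain PyTorch; the mask values and variable names below are illustrative, and both variants select the same positions for masked_fill:

import torch

# Illustrative attention mask: the last position is padding.
attention_mask = torch.tensor([[1, 1, 1, 0]])
scores = torch.zeros(1, 4)

# Before this commit: the inverse mask is materialized as uint8 ("byte"),
# which shows up as a uint8 intermediate tensor in the exported ONNX graph.
r_mask_byte = (1 - attention_mask).byte()   # dtype=torch.uint8

# After this commit: the inverse mask is a bool tensor instead, which
# TensorRT can consume once the model is exported to ONNX.
r_mask_bool = ~attention_mask.bool()        # dtype=torch.bool

# Numerically equivalent masking either way.
filled_byte = scores.masked_fill(r_mask_byte.bool(), torch.finfo(scores.dtype).min)
filled_bool = scores.masked_fill(r_mask_bool, torch.finfo(scores.dtype).min)
assert torch.equal(filled_byte, filled_bool)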
5 changes: 2 additions & 3 deletions src/transformers/models/deberta/modeling_deberta.py
@@ -139,7 +139,7 @@ def symbolic(g, self, mask, dim):
r_mask = g.op(
"Cast",
g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
- to_i=sym_help.cast_pytorch_to_onnx["Byte"],
+ to_i=sym_help.cast_pytorch_to_onnx["Bool"],
)
output = masked_fill(
g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min))
@@ -420,7 +420,6 @@ def get_attention_mask(self, attention_mask):
if attention_mask.dim() <= 2:
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1)
- attention_mask = attention_mask.byte()
elif attention_mask.dim() == 3:
attention_mask = attention_mask.unsqueeze(1)

@@ -614,7 +613,7 @@ def forward(
Input states to the module usually the output from previous layer, it will be the Q,K and V in
*Attention(Q,K,V)*
- attention_mask (`torch.ByteTensor`):
+ attention_mask (`torch.BoolTensor`):
An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size, *N* is the maximum
sequence length in which element [i,j] = *1* means the *i* th token in the input can attend to the *j*
th token.
7 changes: 3 additions & 4 deletions src/transformers/models/deberta_v2/modeling_deberta_v2.py
@@ -130,7 +130,7 @@ def symbolic(g, self, mask, dim):
r_mask = g.op(
"Cast",
g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
- to_i=sym_help.cast_pytorch_to_onnx["Byte"],
+ to_i=sym_help.cast_pytorch_to_onnx["Bool"],
)
output = masked_fill(
g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min))
@@ -453,7 +453,6 @@ def get_attention_mask(self, attention_mask):
if attention_mask.dim() <= 2:
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1)
- attention_mask = attention_mask.byte()
elif attention_mask.dim() == 3:
attention_mask = attention_mask.unsqueeze(1)

@@ -484,7 +483,7 @@ def forward(
if attention_mask.dim() <= 2:
input_mask = attention_mask
else:
- input_mask = (attention_mask.sum(-2) > 0).byte()
+ input_mask = attention_mask.sum(-2) > 0
attention_mask = self.get_attention_mask(attention_mask)
relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos)

@@ -687,7 +686,7 @@ def forward(
Input states to the module usually the output from previous layer, it will be the Q,K and V in
*Attention(Q,K,V)*
- attention_mask (`torch.ByteTensor`):
+ attention_mask (`torch.BoolTensor`):
An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size, *N* is the maximum
sequence length in which element [i,j] = *1* means the *i* th token in the input can attend to the *j*
th token.
7 changes: 3 additions & 4 deletions src/transformers/models/sew_d/modeling_sew_d.py
@@ -559,7 +559,7 @@ def symbolic(g, self, mask, dim):
r_mask = g.op(
"Cast",
g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
- to_i=sym_help.cast_pytorch_to_onnx["Byte"],
+ to_i=sym_help.cast_pytorch_to_onnx["Bool"],
)
output = masked_fill(
g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min))
@@ -754,7 +754,7 @@ def forward(
Input states to the module usually the output from previous layer, it will be the Q,K and V in
*Attention(Q,K,V)*
- attention_mask (`torch.ByteTensor`):
+ attention_mask (`torch.BoolTensor`):
An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size, *N* is the maximum
sequence length in which element [i,j] = *1* means the *i* th token in the input can attend to the *j*
th token.
@@ -1086,7 +1086,6 @@ def get_attention_mask(self, attention_mask):
if attention_mask.dim() <= 2:
extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1)
- attention_mask = attention_mask.byte()
elif attention_mask.dim() == 3:
attention_mask = attention_mask.unsqueeze(1)

@@ -1117,7 +1116,7 @@ def forward(
if attention_mask.dim() <= 2:
input_mask = attention_mask
else:
- input_mask = (attention_mask.sum(-2) > 0).byte()
+ input_mask = attention_mask.sum(-2) > 0
attention_mask = self.get_attention_mask(attention_mask)
relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos)

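As a follow-up sketch (also not from the commit), one way to check that an exported graph is free of the uint8 intermediates TensorRT rejects is to scan it with the onnx package; the helper name and the file path in the commented example are hypothetical:

import onnx
from onnx import TensorProto

def has_uint8_intermediates(onnx_path):
    """Return True if the graph records uint8 value_info or casts to uint8."""
    graph = onnx.load(onnx_path).graph
    # Intermediate tensors with a declared uint8 element type.
    for value_info in graph.value_info:
        if value_info.type.tensor_type.elem_type == TensorProto.UINT8:
            return True
    # Cast nodes targeting uint8 also introduce byte intermediates.
    for node in graph.node:
        if node.op_type == "Cast":
            for attr in node.attribute:
                if attr.name == "to" and attr.i == TensorProto.UINT8:
                    return True
    return False

# Hypothetical usage: exported Deberta/DebertaV2/SEW-D graphs should now return False.
# print(has_uint8_intermediates("deberta_v2.onnx"))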
