diff --git a/src/transformers/models/deberta/modeling_deberta.py b/src/transformers/models/deberta/modeling_deberta.py
index 7c98a2d0d49dd5..62e89c824e6466 100644
--- a/src/transformers/models/deberta/modeling_deberta.py
+++ b/src/transformers/models/deberta/modeling_deberta.py
@@ -139,7 +139,7 @@ def symbolic(g, self, mask, dim):
         r_mask = g.op(
             "Cast",
             g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
-            to_i=sym_help.cast_pytorch_to_onnx["Byte"],
+            to_i=sym_help.cast_pytorch_to_onnx["Bool"],
         )
         output = masked_fill(
             g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min))
@@ -420,7 +420,6 @@ def get_attention_mask(self, attention_mask):
         if attention_mask.dim() <= 2:
             extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
             attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1)
-            attention_mask = attention_mask.byte()
         elif attention_mask.dim() == 3:
             attention_mask = attention_mask.unsqueeze(1)

@@ -614,7 +613,7 @@ def forward(
                 Input states to the module usually the output from previous layer, it will be the Q,K and V in
                 *Attention(Q,K,V)*

-            attention_mask (`torch.ByteTensor`):
+            attention_mask (`torch.BoolTensor`):
                 An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size, *N* is the maximum
                 sequence length in which element [i,j] = *1* means the *i* th token in the input can attend to the *j*
                 th token.
diff --git a/src/transformers/models/deberta_v2/modeling_deberta_v2.py b/src/transformers/models/deberta_v2/modeling_deberta_v2.py
index cd6e6318e3d426..95a288657e4838 100644
--- a/src/transformers/models/deberta_v2/modeling_deberta_v2.py
+++ b/src/transformers/models/deberta_v2/modeling_deberta_v2.py
@@ -130,7 +130,7 @@ def symbolic(g, self, mask, dim):
         r_mask = g.op(
             "Cast",
             g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
-            to_i=sym_help.cast_pytorch_to_onnx["Byte"],
+            to_i=sym_help.cast_pytorch_to_onnx["Bool"],
         )
         output = masked_fill(
             g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min))
@@ -453,7 +453,6 @@ def get_attention_mask(self, attention_mask):
         if attention_mask.dim() <= 2:
             extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
             attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1)
-            attention_mask = attention_mask.byte()
         elif attention_mask.dim() == 3:
             attention_mask = attention_mask.unsqueeze(1)

@@ -484,7 +483,7 @@ def forward(
         if attention_mask.dim() <= 2:
             input_mask = attention_mask
         else:
-            input_mask = (attention_mask.sum(-2) > 0).byte()
+            input_mask = attention_mask.sum(-2) > 0
         attention_mask = self.get_attention_mask(attention_mask)
         relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos)

@@ -687,7 +686,7 @@ def forward(
                 Input states to the module usually the output from previous layer, it will be the Q,K and V in
                 *Attention(Q,K,V)*

-            attention_mask (`torch.ByteTensor`):
+            attention_mask (`torch.BoolTensor`):
                 An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size, *N* is the maximum
                 sequence length in which element [i,j] = *1* means the *i* th token in the input can attend to the *j*
                 th token.
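The DeBERTa and DeBERTa-v2 files change in lockstep: the ONNX symbolic now casts the inverted mask to Bool instead of Byte, the expanded attention mask is no longer converted to uint8, and the docstrings follow. As a rough standalone sketch (not part of the diff; the function names here are illustrative, not the model's own), the eager-mode masking these hunks touch behaves roughly like this, assuming a 2-D padding mask of shape [B, N] with 1 = keep and 0 = pad:

# Standalone sketch, not part of the diff. Helper names are illustrative.
import torch

def expand_attention_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    # [B, N] -> [B, 1, N, N]: token i may attend to token j only if both are real tokens.
    extended = attention_mask.unsqueeze(1).unsqueeze(2)           # [B, 1, 1, N]
    mask = extended * extended.squeeze(-2).unsqueeze(-1)          # [B, 1, N, N]
    # With this PR the trailing `.byte()` is gone, so downstream code treats the
    # result as a boolean "keep" mask rather than a uint8 tensor.
    return mask

def masked_softmax(scores: torch.Tensor, mask: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Mirrors the eager path of XSoftmax-style masking: fill masked-out positions with
    # the most negative finite value, then softmax.
    rmask = ~mask.to(torch.bool)
    scores = scores.masked_fill(rmask, torch.finfo(scores.dtype).min)
    return torch.softmax(scores, dim)

The point of the sketch is that `masked_fill` expects a bool mask (uint8 masks trigger a deprecation warning on recent PyTorch), so the intermediate Byte casts were only extra conversions.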
diff --git a/src/transformers/models/sew_d/modeling_sew_d.py b/src/transformers/models/sew_d/modeling_sew_d.py
index 7cdce062eee335..417fd81c6e9d42 100644
--- a/src/transformers/models/sew_d/modeling_sew_d.py
+++ b/src/transformers/models/sew_d/modeling_sew_d.py
@@ -559,7 +559,7 @@ def symbolic(g, self, mask, dim):
         r_mask = g.op(
             "Cast",
             g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
-            to_i=sym_help.cast_pytorch_to_onnx["Byte"],
+            to_i=sym_help.cast_pytorch_to_onnx["Bool"],
         )
         output = masked_fill(
             g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min))
@@ -754,7 +754,7 @@ def forward(
                 Input states to the module usually the output from previous layer, it will be the Q,K and V in
                 *Attention(Q,K,V)*

-            attention_mask (`torch.ByteTensor`):
+            attention_mask (`torch.BoolTensor`):
                 An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size, *N* is the maximum
                 sequence length in which element [i,j] = *1* means the *i* th token in the input can attend to the *j*
                 th token.
@@ -1086,7 +1086,6 @@ def get_attention_mask(self, attention_mask):
         if attention_mask.dim() <= 2:
             extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
             attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1)
-            attention_mask = attention_mask.byte()
         elif attention_mask.dim() == 3:
             attention_mask = attention_mask.unsqueeze(1)

@@ -1117,7 +1116,7 @@ def forward(
         if attention_mask.dim() <= 2:
             input_mask = attention_mask
         else:
-            input_mask = (attention_mask.sum(-2) > 0).byte()
+            input_mask = attention_mask.sum(-2) > 0
         attention_mask = self.get_attention_mask(attention_mask)
         relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos)
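The SEW-D copy of the DeBERTa attention code gets the same treatment. A quick illustration (outside the diff, with made-up shapes) of why the `.byte()` call on `input_mask` can simply be dropped in the encoder forward: the comparison already produces a torch.bool tensor, which is what boolean masking downstream expects.

# Illustration only; shapes are hypothetical.
import torch

attention_mask = torch.ones(2, 1, 4, 4)     # e.g. a 4-D mask of shape [B, 1, N, N]
input_mask = attention_mask.sum(-2) > 0     # the > comparison returns torch.bool directly
assert input_mask.dtype == torch.bool       # no extra .byte() cast needed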