
Commit a8851d7

fix qwen_vl and load_pretrained patch (#2190)
1 parent b45dbf6 commit a8851d7

2 files changed: +3 −3 lines

mindnlp/transformers/modeling_utils.py

Lines changed: 2 additions & 2 deletions
@@ -184,8 +184,8 @@ def wrapper(
         pretrained_model_name_or_path,
         **kwargs,
     ):
-        device_map = kwargs.pop("device_map", None)
-        sharded_metadata = kwargs.pop("sharded_metadata", None)
+        device_map = kwargs.get("device_map", None)
+        sharded_metadata = kwargs.get("sharded_metadata", None)

         # if device_map is not None and not initialize distribute module, raise Error.
         if device_map is not None:
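The change above swaps kwargs.pop for kwargs.get, so the wrapper can inspect device_map and sharded_metadata without removing them from kwargs before the wrapped loader runs. A minimal sketch of that difference (hypothetical decorator and loader names, not mindnlp's actual implementation):

# Sketch only: pop() would consume the keyword argument inside the wrapper, so the
# wrapped loader would never receive it; get() reads the value but leaves kwargs
# intact for the downstream call.
import functools

def load_pretrained_patch(fn):
    @functools.wraps(fn)
    def wrapper(pretrained_model_name_or_path, **kwargs):
        device_map = kwargs.get("device_map", None)            # non-destructive read
        sharded_metadata = kwargs.get("sharded_metadata", None)
        if device_map is not None:
            pass  # e.g. validate that distributed modules are initialized
        return fn(pretrained_model_name_or_path, **kwargs)     # kwargs still complete
    return wrapper

@load_pretrained_patch
def fake_from_pretrained(name, **kwargs):
    # Stand-in for the real loader; just echoes what it received.
    return {"name": name, **kwargs}

print(fake_from_pretrained("some/model", device_map="auto"))
# {'name': 'some/model', 'device_map': 'auto'} -- device_map reached the loader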

mindtorch/nn/functional.py

Lines changed: 1 addition & 1 deletion
@@ -1220,7 +1220,7 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.

    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
-    attn_weight = softmax(attn_weight, dim=-1, dtype=mindtorch.float32).to(query.dtype)
+    attn_weight = softmax(attn_weight, dim=-1)
    attn_weight = dropout(attn_weight, dropout_p, training=True)
    return attn_weight @ value
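For reference, this function implements standard scaled dot-product attention: weights = softmax(Q @ K.T * scale + bias), output = dropout(weights) @ V. The touched line previously forced the softmax into float32 and cast the result back to the query dtype; after this commit the softmax runs in the inputs' native dtype. A small NumPy sketch of the same math (a reference illustration under simplifying assumptions, not mindtorch's API):

# Assumptions: masking reduced to an additive bias, dropout omitted.
import math
import numpy as np

def stable_softmax(x, axis=-1):
    # Numerically stable softmax along the given axis, computed in x's own dtype.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def sdpa_reference(query, key, value, attn_bias=0.0):
    # query, key, value: (..., seq_len, head_dim)
    scale_factor = 1.0 / math.sqrt(query.shape[-1])
    attn_weight = query @ np.swapaxes(key, -2, -1) * scale_factor
    attn_weight = attn_weight + attn_bias
    attn_weight = stable_softmax(attn_weight, axis=-1)  # native dtype, as in the new code
    return attn_weight @ value

q = k = v = np.random.randn(2, 4, 8).astype(np.float16)
print(sdpa_reference(q, k, v).shape)  # (2, 4, 8)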
