
Commit 3fbc94b

speedup sdpa_mask for mindspore (#2113)
Parent: 7fda79a

2 files changed: +19 −19 lines

mindnlp/core/ops/creation.py

Lines changed: 18 additions & 18 deletions
@@ -183,24 +183,24 @@ def empty(*size, dtype=None, device=None, requires_grad=False, pin_memory=False,
     if dtype is None:
         dtype = get_default_dtype()

-    if device:
-        if not isinstance(device, str) and hasattr(device, "type"):
-            device = device.type
-        if device.lower() == 'cpu':
-            device = 'CPU'
-        elif device.lower() == 'npu':
-            device = 'Ascend'
-        elif device.lower() == 'cuda':
-            device = 'GPU'
-    else:
-        device = 'meta'
-
-    # To avoid the problem in irecv and recv of using empty.
-    if device not in ['meta', 'GPU']:
-        out = mindspore.mint.empty(size, dtype=dtype, device=device)
-    else:
-        out = CTensor(dtype=dtype, shape=size)
-    out = mindspore.Tensor(out)
+    # if device:
+    #     if not isinstance(device, str) and hasattr(device, "type"):
+    #         device = device.type
+    #     if device.lower() == 'cpu':
+    #         device = 'CPU'
+    #     elif device.lower() == 'npu':
+    #         device = 'Ascend'
+    #     elif device.lower() == 'cuda':
+    #         device = 'GPU'
+    # else:
+    #     device = 'meta'
+
+    # # To avoid the problem in irecv and recv of using empty.
+    # if device not in ['meta', 'GPU']:
+    #     out = mindspore.mint.empty(size, dtype=dtype, device=device)
+    # else:
+    out = CTensor(dtype=dtype, shape=size)
+    out = mindspore.Tensor(out)
     # else:
     #     out = np.empty(size, dtype=dtype2np[dtype])
     #     out = mindspore.Tensor(out)
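Net effect of this hunk: `empty` no longer normalizes the `device` argument ('cpu' → 'CPU', 'npu' → 'Ascend', 'cuda' → 'GPU') and no longer dispatches to `mindspore.mint.empty` for CPU/Ascend; every call now builds an uninitialized `CTensor` and wraps it. A minimal sketch of the live code path after the change, assuming creation.py's existing `get_default_dtype` helper and `CTensor` import; the parameters elided in the hunk header are abbreviated to `**kwargs`:

    # Sketch of empty() as it reads after this commit, not the verbatim file.
    def empty(*size, dtype=None, device=None, requires_grad=False,
              pin_memory=False, **kwargs):  # **kwargs stands in for elided params
        if dtype is None:
            dtype = get_default_dtype()
        # The device-normalization branch and the mint.empty path are now
        # commented out, so allocation is one uninitialized C tensor
        # regardless of the requested device.
        out = CTensor(dtype=dtype, shape=size)
        out = mindspore.Tensor(out)
        ...  # the rest of the function lies outside this hunk

One consequence worth noting: with the branch commented out, `device` is accepted but effectively ignored on this path, and the string comparisons plus the `mindspore.mint.empty` dispatch disappear from every allocation, which is presumably where part of the speedup comes from.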

mindnlp/transformers/masking_utils.py

Lines changed: 1 addition & 1 deletion
@@ -198,7 +198,7 @@ def _ignore_causal_mask_sdpa(
     allowing to dispatch to the flash attention kernel (that can otherwise not be used if a custom `attn_mask` is
     passed).
     """
-    is_tracing = core.jit.is_tracing() or isinstance(padding_mask, core.fx.Proxy) or is_torchdynamo_compiling()
+    is_tracing = core.jit.is_tracing() or isinstance(padding_mask, core.fx.Proxy)
     if padding_mask is not None and padding_mask.shape[-1] > kv_length:
         mask_indices = core.arange(kv_length, device=padding_mask.device)
         mask_indices += kv_offset
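The one-line change drops `is_torchdynamo_compiling()` from the tracing check. Torchdynamo is a PyTorch compiler probe with no MindSpore counterpart, so evaluating it on every mask construction added overhead while (presumably) always returning false in this environment; `is_tracing` now covers only MindSpore graph tracing and FX proxies. A rough sketch of how the flag gates the early exit, simplified from the fuller logic in masking_utils.py (the real function takes more arguments and handles more cases):

    # Simplified sketch, not the verbatim function.
    def _ignore_causal_mask_sdpa(padding_mask, kv_length, kv_offset):
        # After this commit: no torchdynamo probe in the tracing check.
        is_tracing = core.jit.is_tracing() or isinstance(padding_mask, core.fx.Proxy)
        if is_tracing:
            # Data-dependent shortcuts are unsafe while tracing: keep the mask.
            return False
        # With no padded positions, the causal mask can be dropped entirely.
        return padding_mask is None or bool(padding_mask.all())

Returning True tells the caller it can skip materializing the causal mask and pass `attn_mask=None` to scaled dot-product attention, matching the docstring's note about dispatching to the flash attention kernel.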
