Skip to content

Commit 7669b5c

Browse files
authored
fix transformers ut for 4.54.1 (#2117)
1 parent 6c855e8 commit 7669b5c

File tree

5 files changed

+22
-2
lines changed

5 files changed

+22
-2
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,4 +169,5 @@ flagged/
169169

170170
tests/diffusers/
171171
tests/transformers/
172+
tests/huggingface_transformers/
172173
.gradio/

mindnlp/core/_tensor.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -815,6 +815,15 @@ def record_stream(self, stream):
815815
Tensor.scatter = ops.scatter
816816
StubTensor.scatter = ops.scatter
817817

818+
Tensor.mul = ops.mul
819+
StubTensor.mul = ops.mul
820+
821+
Tensor.index_select = ops.index_select
822+
StubTensor.index_select = ops.index_select
823+
824+
Tensor.gather = ops.gather
825+
StubTensor.gather = ops.gather
826+
818827
def _rebuild_from_type_v2(func, new_type, args, state):
819828
ret = func(*args)
820829
return ret

mindnlp/core/npu/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,3 +58,9 @@ def __enter__(self):
5858

5959
def __exit__(self, type: Any, value: Any, traceback: Any):
6060
return False
61+
62+
def mem_get_info(index):
63+
return (1024, 1024)
64+
65+
def current_device():
66+
return core.device('npu', 0)

mindnlp/core/ops/inplace.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ def inplace_random(self, from_=0, to=None, *, generator=None):
209209
generator = default_generator
210210
seed, offset = generator._step( # pylint: disable=protected-access
211211
generator_step_)
212-
return inplace_random_op(input, from_, to, seed, offset)
212+
return inplace_random_op(self, from_, to, seed, offset)
213213
else:
214214
if isinstance(self.dtype, typing.Float):
215215
self.uniform_(from_, to, generator=generator)

mindnlp/core/types.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,11 @@ def __init__(self, type=None, index=None):
3838
_id = type.index
3939
elif isinstance(type, int):
4040
_id = type
41-
_target = DEVICE_MAP[mindspore.get_current_device().device_target]
41+
try:
42+
device_target = mindspore.get_current_device().device_target
43+
except:
44+
device_target = mindspore.get_context('device_target')
45+
_target = DEVICE_MAP[device_target]
4246
else:
4347
print(type)
4448
raise TypeError("core.device(): `type` must be type of 'str' or 'core.device'.")

0 commit comments

Comments
 (0)