Skip to content

Commit 568e398

Browse files
authored
fix c class models on OrangePi (#2213)
1 parent 2dfb9ad commit 568e398

File tree

4 files changed

+16
-9
lines changed

4 files changed

+16
-9
lines changed

mindtorch/_apis/gpu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -848,7 +848,7 @@ def addcmul(input, tensor1, tensor2, value=1.0):
848848
return legacy.addcmul(input, tensor1, tensor2, mindspore.Tensor(value))
849849

850850
def addmm(input, mat1, mat2, alpha=1.0, beta=1.0):
851-
return add(mul(input, beta), mul(bmm(mat1, mat2), alpha))
851+
return add(mul(input, beta), mul(matmul(mat1, mat2), alpha))
852852

853853
def im2col(input, kernel_size, dilation=1, padding=0, stride=1):
854854
out = legacy.im2_col(input, kernel_size, stride, dilation, padding)

mindtorch/_apis/npu.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import mindspore
2+
import mindtorch
23
from mindspore._c_expression import _empty_instance
34
from ..configs import use_pyboost, ON_A1, ON_ORANGE_PI
45
from .._op_prim.ascend import legacy, pyboost
@@ -824,9 +825,11 @@ def argmax(input, axis, keepdims):
824825
return legacy.argmax(input, axis, keepdims)
825826

826827
def argmin(input, axis, keepdims):
827-
if use_pyboost():
828+
if use_pyboost() and not ON_ORANGE_PI:
828829
return pyboost.argmin_ext_op(input, axis, keepdims)
829-
return legacy.argmin(input, axis, keepdims)
830+
if axis is None:
831+
axis = -1
832+
return legacy.arg_min_with_value(input, axis, keepdims)[0]
830833

831834

832835
def bmm(input, other):
@@ -1136,7 +1139,7 @@ def masked_scatter(input, mask, value):
11361139
return legacy.masked_scatter(input, mask, value)
11371140

11381141
def neg(input):
1139-
if use_pyboost():
1142+
if use_pyboost() and not ON_ORANGE_PI:
11401143
return pyboost.neg_op(input)
11411144
return legacy.neg(input)
11421145

@@ -1557,7 +1560,7 @@ def inplace_exponential(self, lambd, generator):
15571560
return legacy.expo(self, lambd, generator)
15581561

15591562
def im2col(input, kernel_size, dilation=1, padding=0, stride=1):
1560-
if use_pyboost() and not ON_A1:
1563+
if use_pyboost() and not ON_A1 and not ON_ORANGE_PI:
15611564
return pyboost.im2col_ext_op(input, kernel_size, dilation, padding, stride)
15621565
out = legacy.im2_col(input, kernel_size, stride, dilation, padding)
15631566
out_shape = out.shape[:1] + (-1,) + out.shape[-1:]
@@ -1570,9 +1573,10 @@ def upsample_nearest2d(input, output_size, scale_factors):
15701573
return legacy.upsample_nearest2d(input, scale_factor, align_corners)
15711574

15721575
def addmm(input, mat1, mat2, alpha=1.0, beta=1.0):
1573-
if use_pyboost():
1576+
if use_pyboost() and not ON_ORANGE_PI:
15741577
return pyboost.addmm_op(input, mat1, mat2, alpha, beta)
1575-
return legacy.addmm(input, mat1, mat2, alpha, beta)
1578+
return add(mul(input, beta), mul(matmul(mat1, mat2), alpha))
1579+
15761580

15771581
def meshgrid(input, lambd):
15781582
if use_pyboost():

mindtorch/ops/array.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -622,6 +622,9 @@ def _process_multi_dim_index(self, indexes, remain_indexes, indexed_dims):
622622
self_viewed = self
623623
self_viewed_shape = list(self.shape)
624624
dim = 0
625+
if ON_ORANGE_PI:
626+
if all([isinstance(index, slice) for index in indexes]):
627+
return getitem(self_viewed, tuple(indexes)), remain_indexes
625628
for i, index in enumerate(indexes):
626629
if isinstance(index, (list, tuple, np.ndarray)):
627630
index_np = np.array(index) if isinstance(index, (list, tuple)) else index
@@ -634,7 +637,6 @@ def _process_multi_dim_index(self, indexes, remain_indexes, indexed_dims):
634637
raise TypeError(f"Index {index} contain unsupported elements")
635638
self_viewed, dim, remain_indexes, self_viewed_shape = _process_dim_in_multi_dim_index(
636639
self_viewed, self, index, dim, indexed_dims, i, remain_indexes, self_viewed_shape)
637-
638640
return self_viewed, remain_indexes
639641

640642

tests/run_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
import mindnlp
77
from mindnlp import transformers
88

9-
mindspore.set_context(pynative_synchronize=True)
9+
# mindspore.set_context(pynative_synchronize=True)
10+
mindspore.runtime.launch_blocking()
1011

1112
def run_tests():
1213
"""

0 commit comments

Comments
 (0)