Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mindtorch/_apis/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,7 +848,7 @@ def addcmul(input, tensor1, tensor2, value=1.0):
return legacy.addcmul(input, tensor1, tensor2, mindspore.Tensor(value))

def addmm(input, mat1, mat2, alpha=1.0, beta=1.0):
return add(mul(input, beta), mul(bmm(mat1, mat2), alpha))
return add(mul(input, beta), mul(matmul(mat1, mat2), alpha))

def im2col(input, kernel_size, dilation=1, padding=0, stride=1):
out = legacy.im2_col(input, kernel_size, stride, dilation, padding)
Expand Down
16 changes: 10 additions & 6 deletions mindtorch/_apis/npu.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import mindspore
import mindtorch
from mindspore._c_expression import _empty_instance
from ..configs import use_pyboost, ON_A1, ON_ORANGE_PI
from .._op_prim.ascend import legacy, pyboost
Expand Down Expand Up @@ -824,9 +825,11 @@ def argmax(input, axis, keepdims):
return legacy.argmax(input, axis, keepdims)

def argmin(input, axis, keepdims):
if use_pyboost():
if use_pyboost() and not ON_ORANGE_PI:
return pyboost.argmin_ext_op(input, axis, keepdims)
return legacy.argmin(input, axis, keepdims)
if axis is None:
axis = -1
return legacy.arg_min_with_value(input, axis, keepdims)[0]


def bmm(input, other):
Expand Down Expand Up @@ -1136,7 +1139,7 @@ def masked_scatter(input, mask, value):
return legacy.masked_scatter(input, mask, value)

def neg(input):
if use_pyboost():
if use_pyboost() and not ON_ORANGE_PI:
return pyboost.neg_op(input)
return legacy.neg(input)

Expand Down Expand Up @@ -1557,7 +1560,7 @@ def inplace_exponential(self, lambd, generator):
return legacy.expo(self, lambd, generator)

def im2col(input, kernel_size, dilation=1, padding=0, stride=1):
if use_pyboost() and not ON_A1:
if use_pyboost() and not ON_A1 and not ON_ORANGE_PI:
return pyboost.im2col_ext_op(input, kernel_size, dilation, padding, stride)
out = legacy.im2_col(input, kernel_size, stride, dilation, padding)
out_shape = out.shape[:1] + (-1,) + out.shape[-1:]
Expand All @@ -1570,9 +1573,10 @@ def upsample_nearest2d(input, output_size, scale_factors):
return legacy.upsample_nearest2d(input, scale_factor, align_corners)

def addmm(input, mat1, mat2, alpha=1.0, beta=1.0):
if use_pyboost():
if use_pyboost() and not ON_ORANGE_PI:
return pyboost.addmm_op(input, mat1, mat2, alpha, beta)
return legacy.addmm(input, mat1, mat2, alpha, beta)
return add(mul(input, beta), mul(matmul(mat1, mat2), alpha))


def meshgrid(input, lambd):
if use_pyboost():
Expand Down
4 changes: 3 additions & 1 deletion mindtorch/ops/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,9 @@ def _process_multi_dim_index(self, indexes, remain_indexes, indexed_dims):
self_viewed = self
self_viewed_shape = list(self.shape)
dim = 0
if ON_ORANGE_PI:
if all([isinstance(index, slice) for index in indexes]):
return getitem(self_viewed, tuple(indexes)), remain_indexes
for i, index in enumerate(indexes):
if isinstance(index, (list, tuple, np.ndarray)):
index_np = np.array(index) if isinstance(index, (list, tuple)) else index
Expand All @@ -634,7 +637,6 @@ def _process_multi_dim_index(self, indexes, remain_indexes, indexed_dims):
raise TypeError(f"Index {index} contain unsupported elements")
self_viewed, dim, remain_indexes, self_viewed_shape = _process_dim_in_multi_dim_index(
self_viewed, self, index, dim, indexed_dims, i, remain_indexes, self_viewed_shape)

return self_viewed, remain_indexes


Expand Down
3 changes: 2 additions & 1 deletion tests/run_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import mindnlp
from mindnlp import transformers

mindspore.set_context(pynative_synchronize=True)
# mindspore.set_context(pynative_synchronize=True)
mindspore.runtime.launch_blocking()

def run_tests():
"""
Expand Down
Loading