Skip to content

Commit 9f09d57

Browse files
committed
fix conv2d on OrangePi
1 parent 5240760 commit 9f09d57

File tree

1 file changed

+37
-4
lines changed

1 file changed

+37
-4
lines changed

mindtorch/_apis/npu.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1025,14 +1025,47 @@ def flatten(input, start_dim, end_dim):
10251025
return legacy.reshape(input, tuple(input_shape))
10261026

10271027
def conv2d_padding(input, weight, bias=None, stride=1, padding='valid', dilation=1, groups=1):
1028-
if use_pyboost():
1028+
if use_pyboost() and not ON_ORANGE_PI:
10291029
return pyboost.conv2d_padding_op(input, weight, bias, stride, padding, dilation, groups)
1030-
return legacy.conv2d(input, weight, bias, stride, padding, dilation, groups)
1030+
return conv2d_legacy(input, weight, bias, stride, padding, dilation, groups)
10311031

10321032
def conv2d(input, weight, bias=None, stride=1, padding='valid', dilation=1, groups=1):
1033-
if use_pyboost():
1033+
if use_pyboost() and not ON_ORANGE_PI:
10341034
return pyboost.conv2d_ext_op(input, weight, bias, stride, padding, dilation, groups)
1035-
return legacy.conv2d(input, weight, bias, stride, padding, dilation, groups)
1035+
return conv2d_legacy(input, weight, bias, stride, padding, dilation, groups)
1036+
1037+
def conv2d_legacy(input, weight, bias=None, stride=1, padding='valid', dilation=1, groups=1):
1038+
pad_mode = 'pad'
1039+
pad = padding
1040+
if isinstance(padding, (tuple, list)):
1041+
pad = (padding[0], padding[0], padding[1], padding[1])
1042+
elif isinstance(padding, int):
1043+
pad = (padding,) * 4
1044+
if not isinstance(padding, (int, tuple, list)):
1045+
pad_mode = padding
1046+
pad = (0,) * 4
1047+
1048+
if isinstance(stride, int):
1049+
stride = (stride,) * 4
1050+
1051+
out_channels = weight.shape[0]
1052+
kernel_size = weight.shape[2:]
1053+
1054+
output = legacy.conv2_d(
1055+
input, weight,
1056+
out_channels,
1057+
kernel_size,
1058+
1,#mode=1,
1059+
pad_mode, #pad_mode=pad_mode,
1060+
pad, #pad=pad,
1061+
tuple(stride), #stride=tuple(stride),
1062+
dilation, #dilation=dilation,
1063+
groups, #group=groups,
1064+
"NCHW", #data_format="NCHW"
1065+
)
1066+
if bias is not None:
1067+
output = legacy.bias_add(output, bias, "NCHW")
1068+
return output
10361069

10371070
def cos(input):
10381071
if use_pyboost():

0 commit comments

Comments
 (0)