@@ -1150,6 +1150,77 @@ def rank_2d(rank_t[:, :] in_arr, axis=0, ties_method='average',
11501150 return ranks
11511151
11521152
1153+ ctypedef fused diff_t:
1154+ float64_t
1155+ float32_t
1156+ int8_t
1157+ int16_t
1158+ int32_t
1159+ int64_t
1160+
1161+ ctypedef fused out_t:
1162+ float32_t
1163+ float64_t
1164+
1165+
1166+ @ cython.boundscheck (False )
1167+ @ cython.wraparound (False )
1168+ def diff_2d (ndarray[diff_t , ndim = 2 ] arr,
1169+ ndarray[out_t , ndim = 2 ] out,
1170+ Py_ssize_t periods , int axis ):
1171+ cdef:
1172+ Py_ssize_t i, j, sx, sy, start, stop
1173+ bint f_contig = arr.flags.f_contiguous
1174+
1175+ # Disable for unsupported dtype combinations,
1176+ # see https://github.com/cython/cython/issues/2646
1177+ if (out_t is float32_t
1178+ and not (diff_t is float32_t or diff_t is int8_t or diff_t is int16_t)):
1179+ raise NotImplementedError
1180+ elif (out_t is float64_t
1181+ and (diff_t is float32_t or diff_t is int8_t or diff_t is int16_t)):
1182+ raise NotImplementedError
1183+ else :
1184+ # We put this inside an indented else block to avoid cython build
1185+ # warnings about unreachable code
1186+ sx, sy = (< object > arr).shape
1187+ with nogil:
1188+ if f_contig:
1189+ if axis == 0 :
1190+ if periods >= 0 :
1191+ start, stop = periods, sx
1192+ else :
1193+ start, stop = 0 , sx + periods
1194+ for j in range (sy):
1195+ for i in range (start, stop):
1196+ out[i, j] = arr[i, j] - arr[i - periods, j]
1197+ else :
1198+ if periods >= 0 :
1199+ start, stop = periods, sy
1200+ else :
1201+ start, stop = 0 , sy + periods
1202+ for j in range (start, stop):
1203+ for i in range (sx):
1204+ out[i, j] = arr[i, j] - arr[i, j - periods]
1205+ else :
1206+ if axis == 0 :
1207+ if periods >= 0 :
1208+ start, stop = periods, sx
1209+ else :
1210+ start, stop = 0 , sx + periods
1211+ for i in range (start, stop):
1212+ for j in range (sy):
1213+ out[i, j] = arr[i, j] - arr[i - periods, j]
1214+ else :
1215+ if periods >= 0 :
1216+ start, stop = periods, sy
1217+ else :
1218+ start, stop = 0 , sy + periods
1219+ for i in range (sx):
1220+ for j in range (start, stop):
1221+ out[i, j] = arr[i, j] - arr[i, j - periods]
1222+
1223+
11531224# generated from template
11541225include " algos_common_helper.pxi"
11551226include " algos_take_helper.pxi"
0 commit comments