|
| 1 | +import operator |
| 2 | + |
| 3 | +import numpy as np |
| 4 | + |
| 5 | +import dpctl.memory as dpm |
| 6 | +import dpctl.tensor as dpt |
| 7 | + |
| 8 | + |
def contract_iter2(shape, strides1, strides2):
    """Simplify a pair of strided iteration spaces over the same shape.

    Axes are reordered by decreasing ``|strides1|``, axes that are
    negative in *both* stride sets are flipped (folding the flip into
    the returned displacements), and adjacent axes that are jointly
    contiguous are merged.  Returns a 5-tuple
    ``(dims, steps1, offset1, steps2, offset2)`` where ``dims`` is the
    simplified shape and ``steps*``/``offset*`` are element strides and
    starting displacements for each array.
    """
    # visit axes from the largest |stride1| to the smallest
    order = np.argsort(np.abs(strides1))[::-1]
    dims = [operator.index(shape[i]) for i in order]
    offset1 = 0
    offset2 = 0
    steps1 = []
    steps2 = []
    mergeable = True
    for i in order:
        s1 = operator.index(strides1[i])
        s2 = operator.index(strides2[i])
        if s1 < 0 and s2 < 0:
            # both arrays step backwards along this axis: reverse it and
            # move the start of iteration to the axis's last element
            offset1 += s1 * (shape[i] - 1)
            offset2 += s2 * (shape[i] - 1)
            s1, s2 = -s1, -s2
        if s1 < 0 or s2 < 0:
            # mixed signs cannot be normalized; skip axis merging
            mergeable = False
        steps1.append(s1)
        steps2.append(s2)
    while mergeable:
        for i in range(len(dims) - 1):
            inner1 = steps1[i + 1]
            inner2 = steps2[i + 1]
            # axes i and i+1 are jointly contiguous when stepping off the
            # end of axis i+1 lands exactly one inner step further
            if (
                steps1[i] - (dims[i + 1] - 1) * inner1 == inner1
                and steps2[i] - (dims[i + 1] - 1) * inner2 == inner2
            ):
                dims[i] *= dims[i + 1]
                del dims[i + 1]
                del steps1[i]
                del steps2[i]
                break
        else:
            # a full pass found nothing to merge: done
            break
    return (dims, steps1, offset1, steps2, offset2)
| 50 | + |
| 51 | + |
def has_memory_overlap(x1, x2):
    """Return True if the USM allocations backing `x1` and `x2` overlap.

    Arrays bound to different SYCL devices are reported as
    non-overlapping.
    """
    m1 = dpm.as_usm_memory(x1)
    m2 = dpm.as_usm_memory(x2)
    if m1.sycl_device == m2.sycl_device:
        p1_beg = m1._pointer
        p1_end = p1_beg + m1.nbytes
        p2_beg = m2._pointer
        p2_end = p2_beg + m2.nbytes
        # BUG FIX: was `p1_beg > p2_end or p2_beg < p1_end`, which falsely
        # reported overlap whenever m2 lay strictly before m1.  Half-open
        # intervals [beg, end) intersect iff each begins before the other
        # ends.
        return p1_beg < p2_end and p2_beg < p1_end
    else:
        return False
| 63 | + |
| 64 | + |
def copy_to_numpy(ary):
    """Return a host numpy.ndarray mirroring the content of
    usm_ndarray `ary`, preserving its shape, strides, and offset.

    Raises TypeError if `ary` is not a dpt.usm_ndarray.
    """
    if type(ary) is not dpt.usm_ndarray:
        raise TypeError
    # copy the whole underlying USM allocation to host and reinterpret it
    host_buf = ary.usm_data.copy_to_host().view(ary.dtype)
    item_size = ary.itemsize
    # usm_ndarray strides are in elements; numpy wants bytes
    byte_strides = tuple(stride * item_size for stride in ary.strides)
    byte_offset = (
        ary.__sycl_usm_array_interface__.get("offset", 0) * item_size
    )
    return np.ndarray(
        ary.shape,
        dtype=ary.dtype,
        buffer=host_buf,
        strides=byte_strides,
        offset=byte_offset,
    )
| 79 | + |
| 80 | + |
def copy_from_numpy(np_ary, usm_type="device", queue=None):
    """Copy numpy array `np_ary` into a newly allocated usm_ndarray.

    The new array has USM allocation of kind `usm_type`; `queue`, when
    given, is passed to the USM buffer constructor.
    """
    # np.require may copy to ensure the host data is aligned, C-contiguous,
    # own-data and writeable-after
    host_ary = np.require(np_ary, requirements=["A", "O", "C", "E"])
    ctor_kwargs = {"queue": queue} if queue else dict()
    res = dpt.usm_ndarray(
        host_ary.shape,
        dtype=host_ary.dtype,
        buffer=usm_type,
        buffer_ctor_kwargs=ctor_kwargs,
    )
    # transfer the flattened host bytes into the USM allocation
    res.usm_data.copy_from_host(host_ary.reshape((-1)).view("u1"))
    return res
| 97 | + |
| 98 | + |
def copy_from_numpy_into(dst, np_ary):
    """Copy host array `np_ary` into existing usm_ndarray `dst`,
    broadcasting it to `dst.shape` and transferring one element at a time.

    Raises TypeError if `np_ary` is not a numpy.ndarray.
    """
    if not isinstance(np_ary, np.ndarray):
        raise TypeError("Expected numpy.ndarray, got {}".format(type(np_ary)))
    host_view = np.broadcast_to(np.asarray(np_ary, dtype=dst.dtype), dst.shape)
    for flat_index in range(dst.size):
        multi_index = np.unravel_index(flat_index, dst.shape)
        # one element, as raw bytes, copied host -> device
        elem_bytes = np.array(host_view[multi_index], ndmin=1).view("u1")
        elem_mem = dpm.as_usm_memory(dst[multi_index])
        elem_mem.copy_from_host(elem_bytes)
| 108 | + |
| 109 | + |
class Dummy:
    """Minimal adapter exposing a given dict as
    ``__sycl_usm_array_interface__`` so that ``dpm.as_usm_memory`` can
    build a USM memory view from a manually constructed interface.
    """

    def __init__(self, iface):
        # `iface` is a __sycl_usm_array_interface__-style dictionary
        self.__sycl_usm_array_interface__ = iface
| 113 | + |
| 114 | + |
def copy_same_dtype(dst, src):
    """Copy the content of usm_ndarray `src` into usm_ndarray `dst`.

    Both arguments must be dpt.usm_ndarray instances with equal shapes
    and equal data types; TypeError/ValueError is raised otherwise.
    """
    if type(dst) is not dpt.usm_ndarray or type(src) is not dpt.usm_ndarray:
        raise TypeError

    if dst.shape != src.shape:
        raise ValueError

    if dst.dtype != src.dtype:
        raise ValueError

    # check that memory regions do not overlap
    if has_memory_overlap(dst, src):
        # round-trip through host memory: a direct device copy between
        # overlapping regions could read already-overwritten source data
        tmp = copy_to_numpy(src)
        copy_from_numpy_into(dst, tmp)
        return

    # flags & 1 appears to be the C-contiguity flag (matches dpctl's
    # USM_ARRAY_C_CONTIGUOUS) -- two contiguous same-shape same-dtype
    # arrays admit a single device-to-device bulk copy
    if (dst.flags & 1) and (src.flags & 1):
        dst_mem = dpm.as_usm_memory(dst)
        src_mem = dpm.as_usm_memory(src)
        dst_mem.copy_from_device(src_mem)
        return

    # simplify strides
    sh_i, dst_st, dst_disp, src_st, src_disp = contract_iter2(
        dst.shape, dst.strides, src.strides
    )
    # Turn both interfaces into zero-dimensional descriptors; each Dummy
    # wrapper below then describes a single element at a computed offset.
    src_iface = src.__sycl_usm_array_interface__
    dst_iface = dst.__sycl_usm_array_interface__
    src_iface["shape"] = tuple()
    src_iface.pop("strides", None)
    dst_iface["shape"] = tuple()
    dst_iface.pop("strides", None)
    # fold the arrays' own starting offsets into the displacements
    dst_disp = dst_disp + dst_iface.get("offset", 0)
    src_disp = src_disp + src_iface.get("offset", 0)
    # element-by-element device-to-device copy over the contracted index
    # space (sh_i has as few axes as contract_iter2 could achieve)
    for i in range(dst.size):
        mi = np.unravel_index(i, sh_i)
        dst_offset = dst_disp
        src_offset = src_disp
        for j, dst_stj, src_stj in zip(mi, dst_st, src_st):
            dst_offset = dst_offset + j * dst_stj
            src_offset = src_offset + j * src_stj
        dst_iface["offset"] = dst_offset
        src_iface["offset"] = src_offset
        msrc = dpm.as_usm_memory(Dummy(src_iface))
        mdst = dpm.as_usm_memory(Dummy(dst_iface))
        mdst.copy_from_device(msrc)
| 164 | + |
| 165 | + |
def copy_same_shape(dst, src):
    """Copy usm_ndarray `src` into usm_ndarray `dst` of the same shape,
    casting element values to `dst.dtype` when the data types differ.
    """
    if src.dtype == dst.dtype:
        copy_same_dtype(dst, src)
        # BUG FIX: previously fell through and copied a second time via
        # the generic path below; with overlapping src/dst the second
        # pass re-read source memory already clobbered by the first copy.
        return

    # check that memory regions do not overlap
    if has_memory_overlap(dst, src):
        # convert on the host, then write back, to avoid reading
        # partially overwritten source data
        tmp = copy_to_numpy(src)
        tmp = tmp.astype(dst.dtype)
        copy_from_numpy_into(dst, tmp)
        return

    # simplify strides to minimize the number of per-element transfers
    sh_i, dst_st, dst_disp, src_st, src_disp = contract_iter2(
        dst.shape, dst.strides, src.strides
    )
    # zero-dimensional interface dicts: each Dummy wrapper below
    # describes a single element at a computed offset
    src_iface = src.__sycl_usm_array_interface__
    dst_iface = dst.__sycl_usm_array_interface__
    src_iface["shape"] = tuple()
    src_iface.pop("strides", None)
    dst_iface["shape"] = tuple()
    dst_iface.pop("strides", None)
    # fold the arrays' own starting offsets into the displacements
    dst_disp = dst_disp + dst_iface.get("offset", 0)
    src_disp = src_disp + src_iface.get("offset", 0)
    # per element: bring the source value to the host, cast, write back
    for i in range(dst.size):
        mi = np.unravel_index(i, sh_i)
        dst_offset = dst_disp
        src_offset = src_disp
        for j, dst_stj, src_stj in zip(mi, dst_st, src_st):
            dst_offset = dst_offset + j * dst_stj
            src_offset = src_offset + j * src_stj
        dst_iface["offset"] = dst_offset
        src_iface["offset"] = src_offset
        msrc = dpm.as_usm_memory(Dummy(src_iface))
        mdst = dpm.as_usm_memory(Dummy(dst_iface))
        tmp = msrc.copy_to_host().view(src.dtype)
        tmp = tmp.astype(dst.dtype)
        mdst.copy_from_host(tmp.view("u1"))
| 206 | + |
| 207 | + |
def copy_from_usm_ndarray_to_usm_ndarray(dst, src):
    """Copy usm_ndarray `src` into usm_ndarray `dst`, broadcasting
    `src` up to `dst`'s dimensionality when needed.

    Raises TypeError if either argument is not a dpt.usm_ndarray, and
    ValueError when the shapes are not broadcast-compatible with `dst`.
    """
    if type(dst) is not dpt.usm_ndarray or type(src) is not dpt.usm_ndarray:
        raise TypeError

    if dst.ndim == src.ndim and dst.shape == src.shape:
        copy_same_shape(dst, src)
        # BUG FIX: previously fell through and performed the copy a
        # second time through the broadcast path below.
        return

    try:
        common_shape = np.broadcast_shapes(dst.shape, src.shape)
    except ValueError:
        raise ValueError

    if dst.size < src.size:
        raise ValueError

    # leading broadcast axes must be size-1 to be absorbed into dst's rank
    if len(common_shape) > dst.ndim:
        ones_count = len(common_shape) - dst.ndim
        for k in range(ones_count):
            if common_shape[k] != 1:
                raise ValueError
        common_shape = common_shape[ones_count:]

    if src.ndim < len(common_shape):
        # prepend zero strides so the broadcast axes repeat src's data
        new_src_strides = (0,) * (len(common_shape) - src.ndim) + src.strides
        src_same_shape = dpt.usm_ndarray(
            common_shape, dtype=src.dtype, buffer=src, strides=new_src_strides
        )
    else:
        # NOTE(review): when src.ndim equals len(common_shape) but src has
        # size-1 axes to broadcast, src is passed through unchanged and
        # copy_same_shape will reject the shape mismatch -- confirm whether
        # same-rank broadcasting is intended to be supported here.
        src_same_shape = src

    copy_same_shape(dst, src_same_shape)
| 239 | + |
| 240 | + |
def astype(usm_ary, newdtype, order="K", casting="unsafe", copy=True):
    """
    astype(usm_array, new_dtype, order="K", casting="unsafe", copy=True)

    Returns a copy of the array, cast to a specified type.

    A view can be returned, if possible, when `copy=False` is used.
    """
    if not isinstance(usm_ary, dpt.usm_ndarray):
        # BUG FIX: was `return TypeError(...)`, which handed the exception
        # instance back to the caller instead of raising it
        raise TypeError(
            "Expected object of type dpt.usm_ndarray, got {}".format(
                type(usm_ary)
            )
        )
    ary_dtype = usm_ary.dtype
    target_dtype = np.dtype(newdtype)
    if not np.can_cast(ary_dtype, target_dtype, casting=casting):
        raise TypeError(
            "Can not cast from {} to {} according to rule {}".format(
                ary_dtype, newdtype, casting
            )
        )
    # flags bit 0 / bit 1 appear to be C-/F-contiguity flags (matching
    # dpctl's usm_ndarray flag convention)
    c_contig = usm_ary.flags & 1
    f_contig = usm_ary.flags & 2
    needs_copy = copy or not (ary_dtype == target_dtype)
    if not needs_copy and (order != "K"):
        # a view cannot honor an explicitly requested memory order the
        # array does not already have
        needs_copy = (c_contig and order not in ["A", "C"]) or (
            f_contig and order not in ["A", "F"]
        )
    if needs_copy:
        # resolve the memory order of the result; default is C
        copy_order = "C"
        if order == "C":
            pass
        elif order == "F":
            copy_order = order
        elif order == "A":
            if usm_ary.flags & 2:
                copy_order = "F"
        elif order == "K":
            if usm_ary.flags & 2:
                copy_order = "F"
        R = dpt.usm_ndarray(
            usm_ary.shape,
            dtype=target_dtype,
            buffer=usm_ary.usm_type,
            order=copy_order,
            buffer_ctor_kwargs={"queue": usm_ary.sycl_queue},
        )
        if order == "K" and (not c_contig and not f_contig):
            # K-order with a non-contiguous source: rebuild R with strides
            # permuted to follow the source's axis ordering
            original_strides = usm_ary.strides
            ind = sorted(
                range(usm_ary.ndim),
                key=lambda i: abs(original_strides[i]),
                reverse=True,
            )
            # NOTE(review): `R.strides[ind[i]] for i in ind` applies the
            # permutation twice -- verify this yields the intended
            # K-order strides rather than `ind`'s inverse being wanted.
            new_strides = tuple(R.strides[ind[i]] for i in ind)
            R = dpt.usm_ndarray(
                usm_ary.shape,
                dtype=target_dtype,
                buffer=R.usm_data,
                strides=new_strides,
            )
        copy_from_usm_ndarray_to_usm_ndarray(R, usm_ary)
        return R
    else:
        # no copy required: hand back the original array as a view
        return usm_ary
0 commit comments