|
1 | 1 | # Data Parallel Control (dpctl)
|
2 | 2 | #
|
3 |
| -# Copyright 2020-2022 Intel Corporation |
| 3 | +# Copyright 2020-2023 Intel Corporation |
4 | 4 | #
|
5 | 5 | # Licensed under the Apache License, Version 2.0 (the "License");
|
6 | 6 | # you may not use this file except in compliance with the License.
|
@@ -741,6 +741,119 @@ def finfo(dtype):
|
741 | 741 | return finfo_object(dtype)
|
742 | 742 |
|
743 | 743 |
|
| 744 | +def unstack(X, axis=0): |
| 745 | + """unstack(x, axis=0) |
| 746 | +
|
| 747 | + Splits an array in a sequence of arrays along the given axis. |
| 748 | +
|
| 749 | + Args: |
| 750 | + x (usm_ndarray): input array |
| 751 | +
|
| 752 | + axis (int, optional): axis along which `x` is unstacked. |
| 753 | + If `x` has rank (i.e, number of dimensions) `N`, |
| 754 | + a valid `axis` must reside in the half-open interval `[-N, N)`. |
| 755 | + Default: `0`. |
| 756 | +
|
| 757 | + Returns: |
| 758 | + Tuple[usm_ndarray,...]: A tuple of arrays. |
| 759 | +
|
| 760 | + Raises: |
| 761 | + AxisError: if the `axis` value is invalid. |
| 762 | + """ |
| 763 | + if not isinstance(X, dpt.usm_ndarray): |
| 764 | + raise TypeError(f"Expected usm_ndarray type, got {type(X)}.") |
| 765 | + |
| 766 | + axis = normalize_axis_index(axis, X.ndim) |
| 767 | + Y = dpt.moveaxis(X, axis, 0) |
| 768 | + |
| 769 | + return tuple(Y[i] for i in range(Y.shape[0])) |
| 770 | + |
| 771 | + |
| 772 | +def moveaxis(X, src, dst): |
| 773 | + """moveaxis(x, src, dst) |
| 774 | +
|
| 775 | + Moves axes of an array to new positions. |
| 776 | +
|
| 777 | + Args: |
| 778 | + x (usm_ndarray): input array |
| 779 | +
|
| 780 | + src (int or a sequence of int): |
| 781 | + Original positions of the axes to move. |
| 782 | + These must be unique. If `x` has rank (i.e., number of |
| 783 | + dimensions) `N`, a valid `axis` must be in the |
| 784 | + half-open interval `[-N, N)`. |
| 785 | +
|
| 786 | + dst (int or a sequence of int): |
| 787 | + Destination positions for each of the original axes. |
| 788 | + These must also be unique. If `x` has rank |
| 789 | + (i.e., number of dimensions) `N`, a valid `axis` must be |
| 790 | + in the half-open interval `[-N, N)`. |
| 791 | +
|
| 792 | + Returns: |
| 793 | + usm_narray: Array with moved axes. |
| 794 | + The returned array must has the same data type as `x`, |
| 795 | + is created on the same device as `x` and has the same |
| 796 | + USM allocation type as `x`. |
| 797 | +
|
| 798 | + Raises: |
| 799 | + AxisError: if `axis` value is invalid. |
| 800 | + """ |
| 801 | + if not isinstance(X, dpt.usm_ndarray): |
| 802 | + raise TypeError(f"Expected usm_ndarray type, got {type(X)}.") |
| 803 | + |
| 804 | + if not isinstance(src, (tuple, list)): |
| 805 | + src = (src,) |
| 806 | + |
| 807 | + if not isinstance(dst, (tuple, list)): |
| 808 | + dst = (dst,) |
| 809 | + |
| 810 | + src = normalize_axis_tuple(src, X.ndim, "src") |
| 811 | + dst = normalize_axis_tuple(dst, X.ndim, "dst") |
| 812 | + ind = list(range(0, X.ndim)) |
| 813 | + for i in range(len(src)): |
| 814 | + ind.remove(src[i]) # using the value here which is the same as index |
| 815 | + ind.insert(dst[i], src[i]) |
| 816 | + |
| 817 | + return dpt.permute_dims(X, tuple(ind)) |
| 818 | + |
| 819 | + |
| 820 | +def swapaxes(X, axis1, axis2): |
| 821 | + """swapaxes(x, axis1, axis2) |
| 822 | +
|
| 823 | + Interchanges two axes of an array. |
| 824 | +
|
| 825 | + Args: |
| 826 | + x (usm_ndarray): input array |
| 827 | +
|
| 828 | + axis1 (int): First axis. |
| 829 | + If `x` has rank (i.e., number of dimensions) `N`, |
| 830 | + a valid `axis` must be in the half-open interval `[-N, N)`. |
| 831 | +
|
| 832 | + axis2 (int): Second axis. |
| 833 | + If `x` has rank (i.e., number of dimensions) `N`, |
| 834 | + a valid `axis` must be in the half-open interval `[-N, N)`. |
| 835 | +
|
| 836 | + Returns: |
| 837 | + usm_narray: Array with swapped axes. |
| 838 | + The returned array must has the same data type as `x`, |
| 839 | + is created on the same device as `x` and has the same USM |
| 840 | + allocation type as `x`. |
| 841 | +
|
| 842 | + Raises: |
| 843 | + AxisError: if `axis` value is invalid. |
| 844 | + """ |
| 845 | + if not isinstance(X, dpt.usm_ndarray): |
| 846 | + raise TypeError(f"Expected usm_ndarray type, got {type(X)}.") |
| 847 | + |
| 848 | + axis1 = normalize_axis_index(axis1, X.ndim, "axis1") |
| 849 | + axis2 = normalize_axis_index(axis2, X.ndim, "axis2") |
| 850 | + |
| 851 | + ind = list(range(0, X.ndim)) |
| 852 | + ind[axis1] = axis2 |
| 853 | + ind[axis2] = axis1 |
| 854 | + return dpt.permute_dims(X, tuple(ind)) |
| 855 | + |
| 856 | + |
744 | 857 | def _supported_dtype(dtypes):
|
745 | 858 | for dtype in dtypes:
|
746 | 859 | if dtype.char not in "?bBhHiIlLqQefdFD":
|
|
0 commit comments