Skip to content

Commit

Permalink
Merge pull request #1167 from IntelPython/fix-take
Browse files — browse the repository at this point in the history
Fix take
  • Loading branch information
oleksandr-pavlyk authored Aug 18, 2022
2 parents 51369c3 + d839ea1 commit cfc9b40
Show file tree
Hide file tree
Showing 9 changed files with 121 additions and 38 deletions.
47 changes: 29 additions & 18 deletions dpnp/backend/kernels/dpnp_krnl_arraycreation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,9 @@ DPCTLSyclEventRef dpnp_ptp_c(DPCTLSyclQueueRef q_ref,
(void)dep_event_vec_ref;

DPCTLSyclEventRef event_ref = nullptr;
DPCTLSyclEventRef e1_ref = nullptr;
DPCTLSyclEventRef e2_ref = nullptr;
DPCTLSyclEventRef e3_ref = nullptr;

if ((input1_in == nullptr) || (result1_out == nullptr))
{
Expand All @@ -514,29 +517,36 @@ DPCTLSyclEventRef dpnp_ptp_c(DPCTLSyclQueueRef q_ref,
_DataType* min_arr = reinterpret_cast<_DataType*>(sycl::malloc_shared(result_size * sizeof(_DataType), q));
_DataType* max_arr = reinterpret_cast<_DataType*>(sycl::malloc_shared(result_size * sizeof(_DataType), q));

dpnp_min_c<_DataType>(arr, min_arr, result_size, input_shape, input_ndim, axis, naxis);
dpnp_max_c<_DataType>(arr, max_arr, result_size, input_shape, input_ndim, axis, naxis);
e1_ref = dpnp_min_c<_DataType>(q_ref, arr, min_arr, result_size, input_shape, input_ndim, axis, naxis, NULL);
e2_ref = dpnp_max_c<_DataType>(q_ref, arr, max_arr, result_size, input_shape, input_ndim, axis, naxis, NULL);

shape_elem_type* _strides =
reinterpret_cast<shape_elem_type*>(sycl::malloc_shared(result_ndim * sizeof(shape_elem_type), q));
get_shape_offsets_inkernel(result_shape, result_ndim, _strides);

dpnp_subtract_c<_DataType, _DataType, _DataType>(result,
result_size,
result_ndim,
result_shape,
result_strides,
max_arr,
result_size,
result_ndim,
result_shape,
_strides,
min_arr,
result_size,
result_ndim,
result_shape,
_strides,
NULL);
e3_ref = dpnp_subtract_c<_DataType, _DataType, _DataType>(q_ref, result,
result_size,
result_ndim,
result_shape,
result_strides,
max_arr,
result_size,
result_ndim,
result_shape,
_strides,
min_arr,
result_size,
result_ndim,
result_shape,
_strides,
NULL, NULL);

DPCTLEvent_Wait(e1_ref);
DPCTLEvent_Wait(e2_ref);
DPCTLEvent_Wait(e3_ref);
DPCTLEvent_Delete(e1_ref);
DPCTLEvent_Delete(e2_ref);
DPCTLEvent_Delete(e3_ref);

sycl::free(min_arr, q);
sycl::free(max_arr, q);
Expand Down Expand Up @@ -576,6 +586,7 @@ void dpnp_ptp_c(void* result1_out,
naxis,
dep_event_vec_ref);
DPCTLEvent_WaitAndThrow(event_ref);
DPCTLEvent_Delete(event_ref);
}

template <typename _DataType>
Expand Down
24 changes: 16 additions & 8 deletions dpnp/backend/kernels/dpnp_krnl_bitwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,16 +148,16 @@ static void func_map_init_bitwise_1arg_1type(func_map_t& fmap)
\
sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref)); \
\
DPNPC_ptr_adapter<_DataType> input1_ptr(q_ref, input1_in, input1_size); \
DPNPC_ptr_adapter<shape_elem_type> input1_shape_ptr(q_ref, input1_shape, input1_ndim, true); \
DPNPC_ptr_adapter<shape_elem_type> input1_strides_ptr(q_ref, input1_strides, input1_ndim, true); \
DPNPC_ptr_adapter<_DataType> input1_ptr(q_ref, input1_in, input1_size); \
DPNPC_ptr_adapter<shape_elem_type> input1_shape_ptr(q_ref, input1_shape, input1_ndim, true); \
DPNPC_ptr_adapter<shape_elem_type> input1_strides_ptr(q_ref, input1_strides, input1_ndim, true); \
\
DPNPC_ptr_adapter<_DataType> input2_ptr(q_ref, input2_in, input2_size); \
DPNPC_ptr_adapter<shape_elem_type> input2_shape_ptr(q_ref, input2_shape, input2_ndim, true); \
DPNPC_ptr_adapter<shape_elem_type> input2_strides_ptr(q_ref, input2_strides, input2_ndim, true); \
DPNPC_ptr_adapter<_DataType> input2_ptr(q_ref, input2_in, input2_size); \
DPNPC_ptr_adapter<shape_elem_type> input2_shape_ptr(q_ref, input2_shape, input2_ndim, true); \
DPNPC_ptr_adapter<shape_elem_type> input2_strides_ptr(q_ref, input2_strides, input2_ndim, true); \
\
DPNPC_ptr_adapter<_DataType> result_ptr(q_ref, result_out, result_size, false, true); \
DPNPC_ptr_adapter<shape_elem_type> result_strides_ptr(q_ref, result_strides, result_ndim); \
DPNPC_ptr_adapter<_DataType> result_ptr(q_ref, result_out, result_size, false, true); \
DPNPC_ptr_adapter<shape_elem_type> result_strides_ptr(q_ref, result_strides, result_ndim); \
\
_DataType* input1_data = input1_ptr.get_ptr(); \
shape_elem_type* input1_shape_data = input1_shape_ptr.get_ptr(); \
Expand Down Expand Up @@ -226,6 +226,14 @@ static void func_map_init_bitwise_1arg_1type(func_map_t& fmap)
}; \
event = q.submit(kernel_func); \
} \
input1_ptr.depends_on(event); \
input1_shape_ptr.depends_on(event); \
input1_strides_ptr.depends_on(event); \
input2_ptr.depends_on(event); \
input2_shape_ptr.depends_on(event); \
input2_strides_ptr.depends_on(event); \
result_ptr.depends_on(event); \
result_strides_ptr.depends_on(event); \
event_ref = reinterpret_cast<DPCTLSyclEventRef>(&event); \
\
return DPCTLEvent_Copy(event_ref); \
Expand Down
23 changes: 23 additions & 0 deletions dpnp/backend/kernels/dpnp_krnl_elemwise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,12 @@
} \
} \
\
input1_ptr.depends_on(event); \
input1_shape_ptr.depends_on(event); \
input1_strides_ptr.depends_on(event); \
result_ptr.depends_on(event); \
result_strides_ptr.depends_on(event); \
\
event_ref = reinterpret_cast<DPCTLSyclEventRef>(&event); \
\
return DPCTLEvent_Copy(event_ref); \
Expand Down Expand Up @@ -644,6 +650,12 @@ static void func_map_init_elemwise_1arg_2type(func_map_t& fmap)
} \
} \
\
input1_ptr.depends_on(event); \
input1_shape_ptr.depends_on(event); \
input1_strides_ptr.depends_on(event); \
result_ptr.depends_on(event); \
result_strides_ptr.depends_on(event); \
\
event_ref = reinterpret_cast<DPCTLSyclEventRef>(&event); \
\
return DPCTLEvent_Copy(event_ref); \
Expand Down Expand Up @@ -998,6 +1010,17 @@ static void func_map_init_elemwise_1arg_1type(func_map_t& fmap)
event = q.submit(kernel_func); \
} \
} \
\
input1_ptr.depends_on(event); \
input1_shape_ptr.depends_on(event); \
input1_strides_ptr.depends_on(event); \
input2_ptr.depends_on(event); \
input2_shape_ptr.depends_on(event); \
input2_strides_ptr.depends_on(event); \
result_ptr.depends_on(event); \
result_shape_ptr.depends_on(event); \
result_strides_ptr.depends_on(event); \
\
event_ref = reinterpret_cast<DPCTLSyclEventRef>(&event); \
\
return DPCTLEvent_Copy(event_ref); \
Expand Down
2 changes: 1 addition & 1 deletion dpnp/backend/kernels/dpnp_krnl_indexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -901,7 +901,7 @@ DPCTLSyclEventRef dpnp_take_c(DPCTLSyclQueueRef q_ref,
DPCTLSyclEventRef event_ref = nullptr;
sycl::queue q = *(reinterpret_cast<sycl::queue*>(q_ref));

DPNPC_ptr_adapter<_DataType> input1_ptr(q_ref, array1_in, array1_size, true);
DPNPC_ptr_adapter<_DataType> input1_ptr(q_ref, array1_in, array1_size);
DPNPC_ptr_adapter<_IndecesType> input2_ptr(q_ref, indices1, size);
_DataType* array_1 = input1_ptr.get_ptr();
_IndecesType* indices = input2_ptr.get_ptr();
Expand Down
8 changes: 8 additions & 0 deletions dpnp/backend/kernels/dpnp_krnl_mathematical.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,8 @@ DPCTLSyclEventRef dpnp_elemwise_absolute_c(DPCTLSyclQueueRef q_ref,
event = q.submit(kernel_func);
}

input1_ptr.depends_on(event);
result1_ptr.depends_on(event);
event_ref = reinterpret_cast<DPCTLSyclEventRef>(&event);

return DPCTLEvent_Copy(event_ref);
Expand Down Expand Up @@ -483,6 +485,8 @@ DPCTLSyclEventRef dpnp_ediff1d_c(DPCTLSyclQueueRef q_ref,
};
event = q.submit(kernel_func);

input1_ptr.depends_on(event);
result_ptr.depends_on(event);
event_ref = reinterpret_cast<DPCTLSyclEventRef>(&event);

return DPCTLEvent_Copy(event_ref);
Expand Down Expand Up @@ -676,6 +680,7 @@ void dpnp_floor_divide_c(void* result_out,
where,
dep_event_vec_ref);
DPCTLEvent_WaitAndThrow(event_ref);
DPCTLEvent_Delete(event_ref);
}

template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2>
Expand Down Expand Up @@ -770,6 +775,7 @@ void dpnp_modf_c(void* array1_in, void* result1_out, void* result2_out, size_t s
size,
dep_event_vec_ref);
DPCTLEvent_WaitAndThrow(event_ref);
DPCTLEvent_Delete(event_ref);
}

template <typename _DataType_input, typename _DataType_output>
Expand Down Expand Up @@ -911,6 +917,7 @@ void dpnp_remainder_c(void* result_out,
where,
dep_event_vec_ref);
DPCTLEvent_WaitAndThrow(event_ref);
DPCTLEvent_Delete(event_ref);
}

template <typename _DataType_output, typename _DataType_input1, typename _DataType_input2>
Expand Down Expand Up @@ -1041,6 +1048,7 @@ void dpnp_trapz_c(
array2_size,
dep_event_vec_ref);
DPCTLEvent_WaitAndThrow(event_ref);
DPCTLEvent_Delete(event_ref);
}

template <typename _DataType_input1, typename _DataType_input2, typename _DataType_output>
Expand Down
2 changes: 2 additions & 0 deletions dpnp/backend/kernels/dpnp_krnl_reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ void dpnp_sum_c(void* result_out,
where,
dep_event_vec_ref);
DPCTLEvent_WaitAndThrow(event_ref);
DPCTLEvent_Delete(event_ref);
}

template <typename _DataType_output, typename _DataType_input>
Expand Down Expand Up @@ -278,6 +279,7 @@ void dpnp_prod_c(void* result_out,
where,
dep_event_vec_ref);
DPCTLEvent_WaitAndThrow(event_ref);
DPCTLEvent_Delete(event_ref);
}

template <typename _DataType_output, typename _DataType_input>
Expand Down
4 changes: 4 additions & 0 deletions dpnp/backend/kernels/dpnp_krnl_sorting.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ void dpnp_argsort_c(void* array1_in, void* result1, size_t size)
size,
dep_event_vec_ref);
DPCTLEvent_WaitAndThrow(event_ref);
DPCTLEvent_Delete(event_ref);
}

template <typename _DataType, typename _idx_DataType>
Expand Down Expand Up @@ -242,6 +243,7 @@ void dpnp_partition_c(
ndim,
dep_event_vec_ref);
DPCTLEvent_WaitAndThrow(event_ref);
DPCTLEvent_Delete(event_ref);
}

template <typename _DataType>
Expand Down Expand Up @@ -394,6 +396,7 @@ void dpnp_searchsorted_c(
v_size,
dep_event_vec_ref);
DPCTLEvent_WaitAndThrow(event_ref);
DPCTLEvent_Delete(event_ref);
}

template <typename _DataType, typename _IndexingType>
Expand Down Expand Up @@ -459,6 +462,7 @@ void dpnp_sort_c(void* array1_in, void* result1, size_t size)
size,
dep_event_vec_ref);
DPCTLEvent_WaitAndThrow(event_ref);
DPCTLEvent_Delete(event_ref);
}

template <typename _DataType>
Expand Down
Loading

0 comments on commit cfc9b40

Please sign in to comment.