Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REVIEW] Adapt to changes in cudf.core.buffer.Buffer #5154

Merged
merged 11 commits into from
Jan 26, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ def subtract_valid(input_array, valid_bool_array, sub_val):
input_array[pos] = input_array[pos] - sub_val


@cudf.core.buffer.acquire_spill_lock()
def get_stem_series(word_str_ser, suffix_len, can_replace_mask):
"""
word_str_ser: input string column
Expand All @@ -95,8 +96,8 @@ def get_stem_series(word_str_ser, suffix_len, can_replace_mask):
start_series = cudf.Series(cp.zeros(len(word_str_ser), dtype=cp.int32))
end_ser = word_str_ser.str.len()

end_ar = end_ser._column.data_array_view
can_replace_mask_ar = can_replace_mask._column.data_array_view
end_ar = end_ser._column.data_array_view(mode="read")
can_replace_mask_ar = can_replace_mask._column.data_array_view(mode="read")

subtract_valid[NBLCK, NTHRD](end_ar, can_replace_mask_ar, suffix_len)
return word_str_ser.str.slice_from(
Expand Down
6 changes: 5 additions & 1 deletion python/cuml/tests/test_input_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,11 @@ def check_numpy_order(ary, order):
def check_ptr(a, b, input_type):
if input_type == 'cudf':
for (_, col_a), (_, col_b) in zip(a._data.items(), b._data.items()):
assert col_a.base_data.ptr == col_b.base_data.ptr
with cudf.core.buffer.acquire_spill_lock():
assert (
col_a.base_data.get_ptr(mode="read") ==
col_b.base_data.get_ptr(mode="read")
)
else:
def get_ptr(x):
try:
Expand Down