Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions vortex-gpu/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ vortex-dict = { workspace = true }
vortex-dtype = { workspace = true }
vortex-error = { workspace = true }
vortex-fastlanes = { workspace = true }
vortex-mask = { workspace = true }
vortex-utils = { workspace = true }

[build-dependencies]
Expand Down
41 changes: 37 additions & 4 deletions vortex-gpu/kernels/dict_take.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
#include <cuda_runtime.h>
#include <stdint.h>


template<typename CodeT, typename ValueT>
__device__ void dict_take_values(
__device__ void dict_take(
const CodeT *__restrict codes_array,
const ValueT *__restrict values,
ValueT *__restrict values_out
Expand All @@ -26,14 +25,48 @@ __device__ void dict_take_values(
}
}

// Masked dictionary take: each thread owns one 32-bit mask word and the 32
// consecutive output slots that word covers. Slots whose mask bit is unset
// are left untouched in values_out.
template<typename CodeT, typename ValueT>
__device__ void dict_take_masked(
    const CodeT *__restrict codes_array,
    const ValueT *__restrict values,
    const uint32_t *__restrict mask_array,
    ValueT *__restrict values_out
) {
    constexpr int BITS_PER_WORD = 32;
    constexpr int BLOCK_ELEMS = 1024;

    const auto tid = threadIdx.x;
    // Each block handles 1024 elements, i.e. 1024 / 32 mask words.
    const auto elem_base = blockIdx.x * BLOCK_ELEMS;
    const auto word_base = blockIdx.x * (BLOCK_ELEMS / BITS_PER_WORD);

    const CodeT *block_codes = codes_array + elem_base;
    ValueT *block_out = values_out + elem_base;
    // The mask word is loop-invariant; load it once.
    const uint32_t word = mask_array[word_base + tid];

    for (int bit = 0; bit < BITS_PER_WORD; bit++) {
        if ((word >> bit) & 1u) {
            const auto idx = tid * BITS_PER_WORD + bit;
            block_out[idx] = values[block_codes[idx]];
        }
    }
}

// Macro to generate the extern "C" wrappers for each (code, value) type
// combination: one unmasked kernel and one masked kernel per pair. The
// generated names (dict_take_c{code}_v{value} and
// dict_take_masked_c{code}_v{value}) must match the names the Rust side
// formats when loading the PTX module.
#define GENERATE_KERNEL(code_suffix, value_suffix, CodeType, ValueType) \
    extern "C" __global__ void dict_take_c##code_suffix##_v##value_suffix( \
        const CodeType *__restrict codes_array, \
        const ValueType *__restrict values, \
        ValueType *__restrict values_out \
    ) { \
        dict_take(codes_array, values, values_out); \
    } \
    extern "C" __global__ void dict_take_masked_c##code_suffix##_v##value_suffix( \
        const CodeType *__restrict codes_array, \
        const ValueType *__restrict values, \
        const uint32_t *__restrict mask, \
        ValueType *__restrict values_out \
    ) { \
        dict_take_masked(codes_array, values, mask, values_out); \
    }

// Generate all combinations
Expand Down
83 changes: 73 additions & 10 deletions vortex-gpu/src/take.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
// SPDX-License-Identifier: Apache-2.0
// SPDX-FileCopyrightText: Copyright the Vortex contributors

use std::mem::transmute;
use std::sync::Arc;

use cudarc::driver::{CudaContext, CudaFunction, DeviceRepr, LaunchConfig, PushKernelArg};
use cudarc::driver::{
CudaContext, CudaFunction, DeviceRepr, LaunchConfig, PushKernelArg, ValidAsZeroBits,
};
use cudarc::nvrtc::Ptx;
use vortex_array::arrays::PrimitiveArray;
use vortex_array::validity::Validity;
Expand All @@ -15,9 +18,18 @@ use vortex_dtype::{
match_each_unsigned_integer_ptype,
};
use vortex_error::{VortexExpect, VortexResult, vortex_err};
use vortex_mask::Mask;

// For now we only support integer non-nullable codes and values.
pub fn cuda_take(dict: &DictArray, ctx: Arc<CudaContext>) -> VortexResult<Option<ArrayRef>> {
cuda_take_masked(dict, None, ctx)
}

// For now we only support integer non-nullable codes and values.
pub fn cuda_take_masked(
dict: &DictArray,
mask: Option<Mask>,
ctx: Arc<CudaContext>,
) -> VortexResult<Option<ArrayRef>> {
if !matches!(dict.dtype(), DType::Primitive(_, Nullability::NonNullable)) {
return Ok(None);
};
Expand All @@ -31,7 +43,7 @@ pub fn cuda_take(dict: &DictArray, ctx: Arc<CudaContext>) -> VortexResult<Option

let result = match_each_native_ptype!(values.ptype(), |V| {
match_each_unsigned_integer_ptype!(codes.ptype(), |C| {
cuda_take_impl::<C, V>(codes, values, ctx)
cuda_take_impl::<C, V>(codes, values, mask, ctx)
})
});
result.map(Some)
Expand All @@ -40,19 +52,20 @@ pub fn cuda_take(dict: &DictArray, ctx: Arc<CudaContext>) -> VortexResult<Option
fn cuda_take_impl<Codes, Values>(
codes: PrimitiveArray,
values: PrimitiveArray,
mask: Option<Mask>,
ctx: Arc<CudaContext>,
) -> VortexResult<ArrayRef>
where
Codes: UnsignedPType + DeviceRepr,
Values: NativePType + DeviceRepr,
Values: NativePType + DeviceRepr + ValidAsZeroBits,
{
let values_sl = values.as_slice::<Values>();
let codes_sl = codes.as_slice::<Codes>();

assert!(values.len() <= 1024);
assert_eq!(codes.len() % 1024, 0);

let kernel_func = cuda_take_kernel::<Codes, Values>(ctx.clone())?;
let kernel_func = cuda_take_kernel::<Codes, Values>(mask.is_some(), ctx.clone())?;
let num_chunks = u32::try_from(codes.len().div_ceil(1024)).vortex_expect("num chunks overflow");
let stream = ctx.default_stream();

Expand All @@ -62,15 +75,33 @@ where
let cu_codes = stream
.memcpy_stod(codes_sl)
.map_err(|e| vortex_err!("Failed to copy to device: {e}"))?;
let mut cu_out = unsafe {
let mut cu_out = {
// TODO(joe): use uninit memory
stream
.alloc::<Values>(codes.len().next_multiple_of(1024))
.alloc_zeros::<Values>(codes.len().next_multiple_of(1024))
.map_err(|e| vortex_err!("Failed to allocate stream: {e}"))?
};

let cu_mask = mask
.map(|mask| {
let buffer = mask.to_boolean_buffer();
assert_eq!(buffer.offset(), 0);
assert_eq!(buffer.len() % 1024, 0);
assert!((buffer.values().as_ptr() as *const u32).is_aligned());
// SAFETY: we've checked alignment and the layout is the same.
let slice: &[u32] = unsafe { transmute(buffer.values()) };
stream
.memcpy_stod(slice)
.map_err(|e| vortex_err!("Failed to copy to device: {e}"))
})
.transpose()?;

let mut launch = stream.launch_builder(&kernel_func);
launch.arg(&cu_codes);
launch.arg(&cu_values);
if let Some(cu_mask) = cu_mask.as_ref() {
launch.arg(cu_mask);
}
launch.arg(&mut cu_out);
unsafe {
launch.launch(LaunchConfig {
Expand All @@ -94,7 +125,7 @@ where
Ok(PrimitiveArray::new(buffer, Validity::NonNullable).into_array())
}

fn cuda_take_kernel<Codes, Values>(ctx: Arc<CudaContext>) -> VortexResult<CudaFunction>
fn cuda_take_kernel<Codes, Values>(mask: bool, ctx: Arc<CudaContext>) -> VortexResult<CudaFunction>
where
Codes: NativePType,
Values: NativePType,
Expand All @@ -103,7 +134,12 @@ where
.load_module(Ptx::from_file("kernels/dict_take.ptx"))
.map_err(|e| vortex_err!("Failed to load kernel module: {e}"))?;

let kernel_name = format!("dict_take_c{}_v{}_values", &Codes::PTYPE, &Values::PTYPE);
let kernel_name = format!(
"dict_take{}_c{}_v{}",
if mask { "_masked" } else { "" },
&Codes::PTYPE,
&Values::PTYPE
);

let kernel_func = module
.load_function(&kernel_name)
Expand All @@ -119,8 +155,9 @@ mod tests {
use vortex_array::{IntoArray, ToCanonical};
use vortex_dict::DictArray;
use vortex_error::VortexExpect;
use vortex_mask::Mask;

use crate::take::cuda_take;
use crate::take::{cuda_take, cuda_take_masked};

#[test]
fn test_cuda_take_u32_u32() {
Expand Down Expand Up @@ -173,4 +210,30 @@ mod tests {

assert_eq!(result.as_slice::<i64>(), expect.as_slice::<i64>());
}

/// Masked GPU take over 8 blocks of 1024 codes: every fourth position is
/// selected, and only selected positions are compared against the CPU
/// canonical decode.
#[test]
fn test_cuda_take_masked() {
    const LEN: u64 = 1024 * 8;

    let values = (0u64..1024)
        .map(|x| (x + 2) % 1024)
        .collect::<PrimitiveArray>();
    let codes = (0u64..LEN)
        .map(|x| (x + 1) % 1024)
        .collect::<PrimitiveArray>();
    let dict = DictArray::try_new(codes.into_array(), values.into_array()).unwrap();

    // CPU reference decode of the full dictionary.
    let expect = dict.to_primitive();

    // Select one element out of every four.
    let mask = Mask::from_iter((0..LEN).map(|i| i % 4 == 0));

    let ctx = CudaContext::new(0).unwrap();
    ctx.set_blocking_synchronize().unwrap();
    let result = cuda_take_masked(&dict, Some(mask.clone()), ctx)
        .unwrap()
        .unwrap()
        .to_primitive();

    let (got, want) = (result.as_slice::<u64>(), expect.as_slice::<u64>());
    for i in mask.to_boolean_buffer().set_indices() {
        assert_eq!(got[i], want[i]);
    }
}
}
Loading