Skip to content

Shader Execution Reordering #145

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions include/drjit-core/jit.h
Original file line number Diff line number Diff line change
Expand Up @@ -1581,14 +1581,17 @@ enum class JitFlag : uint32_t {
/// Set to \c true when traversing inputs or outputs of a frozen function
EnableObjectTraversal = 1 << 22,

/// Reorder threads in OptiX after a ray-intersection
ShaderExecutionReordering = 1 << 23,

/// Default flags
Default = (uint32_t) ConstantPropagation | (uint32_t) ValueNumbering |
(uint32_t) FastMath | (uint32_t) SymbolicLoops |
(uint32_t) OptimizeLoops | (uint32_t) SymbolicCalls |
(uint32_t) MergeFunctions | (uint32_t) OptimizeCalls |
(uint32_t) SymbolicConditionals | (uint32_t) ReuseIndices |
(uint32_t) ScatterReduceLocal | (uint32_t) PacketOps |
(uint32_t) KernelFreezing,
(uint32_t) KernelFreezing | (uint32_t) ShaderExecutionReordering,

// Deprecated aliases, will be removed in a future version of Dr.Jit
LoopRecord = SymbolicLoops,
Expand Down Expand Up @@ -1622,7 +1625,8 @@ enum JitFlag {
JitFlagSymbolic = 1 << 19,
KernelFreezing = 1 << 20,
FreezingScope = 1 << 21,
EnableObjectTraversal = 1 << 22
EnableObjectTraversal = 1 << 22,
JitFlagShaderExecutionReordering = 1 << 23
};
#endif

Expand Down
39 changes: 39 additions & 0 deletions include/drjit-core/optix.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,18 @@ enum OptixHitObjectField {
* and \c hit_object_fields. The results will be stored in new variables whose
* indices are written to \c hit_object_out.
*
* Shader execution reordering can be requested by using the \c reorder flag.
* When the flag is set, the reordering will use the interesected shape's ID
* as a sorting key. In addtion, an extra reordering hint can be passed in the
* \c reorder_hit argument of which only the the last \c reorder_hint_num_bits
* will be used (starting from the lowest signifcant bit). The hint will serve
* as an extra sorting level for threads that intersected the same shape. The
* hint is optional, it can be discared by setting \c reorder_hint_num_bits to
* 0. If you wish to completely ignore the intersected shape's ID for the
* reodering, \ref jit_optix_reorder is more appropriate. Note that if
* \c JitFlag::ShaderExecutionReordering is not set, the \c reorder flag will
* be ignored.
*
* The \c invoke flag determines whether the closest hit and miss programs are
* executed or not.
*
Expand Down Expand Up @@ -207,6 +219,7 @@ enum OptixHitObjectField {
extern JIT_EXPORT void jit_optix_ray_trace(
uint32_t n_args, uint32_t *args, uint32_t n_hit_object_field,
OptixHitObjectField *hit_object_fields, uint32_t *hit_object_out,
int reorder, uint32_t reorder_hint, uint32_t reorder_hint_num_bits,
int invoke, uint32_t mask, uint32_t pipeline, uint32_t sbt);

/**
Expand All @@ -222,6 +235,32 @@ extern JIT_EXPORT uint32_t jit_optix_sbt_data_load(uint32_t sbt_data_ptr,
VarType type,
uint32_t offset,
uint32_t mask);
/**
* \brief Trigger a reordering of the GPU threads
*
* This operation triggers a call to the Shader Execution Reordering feature of
* the GPU. Its main goal is to perform a hardware-level shuffle of the threads
* such that the per-warp divergence can be reduced. To the user, the shuffle
* is invisible - the order within an array is still preserved.
*
* The \c key argument is the JIT index of a 32-bit unsigned integer array that
* defines a hint that is used during the shuffle, similarly to a sorting key.
* However, only \c num_bits of the hint are considered (starting from the least
* signifcant bit). A maximum of 16 bits can be used.
*
* The \c values argument is an array of JIT indices of size \c n_values. Its
* purpose is to define a set of JIT indices to which the reordering can attach
* itself. These \c values are returned in the \c out argument, but as new JIT
* indices that also encode the reordering operation. Effectively, this provides
* the following guarantee: the reordering will take place if any of the \c out
* indices are used in a kernel.
*
* Note that if \c JitFlag::ShaderExecutionReordering is not set, this function
* is a no-op.
*/
extern JIT_EXPORT void jit_optix_reorder(uint32_t key, uint32_t num_bits,
uint32_t n_values, uint32_t *values,
uint32_t *out);

#if defined(__cplusplus)
}
Expand Down
12 changes: 10 additions & 2 deletions src/api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1074,11 +1074,14 @@ void jit_optix_update_sbt(uint32_t index, const OptixShaderBindingTable *sbt) {
void jit_optix_ray_trace(uint32_t n_args, uint32_t *args,
uint32_t n_hit_object_field,
OptixHitObjectField *hit_object_fields,
uint32_t *hit_object_out, int invoke,
uint32_t *hit_object_out,
int reorder, uint32_t reorder_hint,
uint32_t reorder_hint_num_bits, int invoke,
uint32_t mask, uint32_t pipeline, uint32_t sbt) {
lock_guard guard(state.lock);
jitc_optix_ray_trace(n_args, args, n_hit_object_field, hit_object_fields,
hit_object_out, invoke, mask, pipeline, sbt);
hit_object_out, reorder, reorder_hint,
reorder_hint_num_bits, invoke, mask, pipeline, sbt);
}

uint32_t jit_optix_sbt_data_load(uint32_t sbt_data_ptr, VarType type,
Expand All @@ -1087,6 +1090,11 @@ uint32_t jit_optix_sbt_data_load(uint32_t sbt_data_ptr, VarType type,
return jitc_optix_sbt_data_load(sbt_data_ptr, type, offset, mask);
}

void jit_optix_reorder(uint32_t key, uint32_t num_bits, uint32_t n_values,
uint32_t *values, uint32_t *out) {
lock_guard guard(state.lock);
return jitc_optix_reorder(key, num_bits, n_values, values, out);
}
#endif

void jit_llvm_ray_trace(uint32_t func, uint32_t scene, int shadow_ray,
Expand Down
47 changes: 47 additions & 0 deletions src/cuda_eval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ static void jitc_cuda_render_trace(const Variable *v,
const Variable *valid,
const Variable *pipeline,
const Variable *sbt);

static void jitc_cuda_render_reorder(const Variable *, const Variable *);
#endif

void jitc_cuda_assemble(ThreadState *ts, ScheduledGroup group,
Expand Down Expand Up @@ -958,6 +960,11 @@ static void jitc_cuda_render(Variable *v) {
case VarKind::CondOutput: // No output
break;

#if defined(DRJIT_ENABLE_OPTIX)
case VarKind::ReorderThread:
jitc_cuda_render_reorder(v, a0);
break;
#endif

default:
jitc_fail("jitc_cuda_render(): unhandled variable kind \"%s\"!",
Expand Down Expand Up @@ -1036,6 +1043,9 @@ static void jitc_cuda_render_trace(const Variable *v,
" mov.u32 $v_count, $u;\n",
v, v, v, v, payload_count);

// =====================================================
// 1. Traverse
// =====================================================
put(" call (");
for (uint32_t i = 0; i < 32; ++i)
fmt("$v_out_$u$s", v, i, i + 1 < 32 ? ", " : "");
Expand All @@ -1058,6 +1068,25 @@ static void jitc_cuda_render_trace(const Variable *v,

put(");\n");

// =====================================================
// 2. Reorder
// =====================================================
if (td->reorder && jit_flag(JitFlag::ShaderExecutionReordering)) {
if (td->reorder_hint_num_bits == 0) {
fmt(" call (), _optix_hitobject_reorder, ($v_z, $v_z);\n", v, v);
} else {
fmt(" .reg .u32 $v_hint_bits;\n"
" mov.u32 $v_hint_bits, $u;\n"
" call (), _optix_hitobject_reorder, ($v, $v_hint_bits);\n",
v,
v, td->reorder_hint_num_bits,
jitc_var(td->reorder_hint), v);
}
}

// =====================================================
// 3. Get HitObject fields
// =====================================================
size_t n_fields = td->hit_object_fields.size();
for (uint32_t i = 0; i < n_fields; ++i) {
uint32_t field_i = td->hit_object_fields[i];
Expand Down Expand Up @@ -1123,6 +1152,9 @@ static void jitc_cuda_render_trace(const Variable *v,
}
}

// =====================================================
// 4. Invoke miss & closest hit programs
// =====================================================
if (td->invoke) {
put(" call (");
for (uint32_t i = 0; i < 32; ++i)
Expand All @@ -1147,6 +1179,21 @@ static void jitc_cuda_render_trace(const Variable *v,
if (some_masked)
fmt("\nl_masked_$u:\n", v->reg_index);
}

static void jitc_cuda_render_reorder(const Variable *v, const Variable *key) {
if (jit_flag(JitFlag::ShaderExecutionReordering)) {
// Create an outgoing Nop hit object on all lanes, to overwrite any existing
// hit object which could influence the reordering.
put(" call (), _optix_hitobject_make_nop, ();\n");

fmt(" .reg .u32 $v_hint_bits;\n"
" mov.u32 $v_hint_bits, $u;\n"
" call (), _optix_hitobject_reorder, ($v, $v_hint_bits);\n",
v,
v, v->literal,
key, v);
}
}
#endif

/// Virtual function call code generation -- CUDA/PTX-specific bits
Expand Down
3 changes: 3 additions & 0 deletions src/internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,9 @@ enum class VarKind : uint32_t {
// Write an element to a variable array
ArrayWrite,

// Shader execution reordering (OptiX)
ReorderThread,

// Denotes the number of different node types
Count
};
Expand Down
19 changes: 12 additions & 7 deletions src/optix.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,19 +45,23 @@ extern void jitc_optix_update_sbt(uint32_t index, const OptixShaderBindingTable
enum class OptixHitObjectField: uint32_t;

/// Insert a function call to optixTrace into the program
extern void jitc_optix_ray_trace(uint32_t n_args, uint32_t *args,
uint32_t n_hit_object_field,
OptixHitObjectField *hit_object_fields,
uint32_t *hit_object_out, int invoke,
uint32_t mask, uint32_t pipeline,
uint32_t sbt);
extern void jitc_optix_ray_trace(
uint32_t n_args, uint32_t *args, uint32_t n_hit_object_field,
OptixHitObjectField *hit_object_fields, uint32_t *hit_object_out,
int reorder, uint32_t reorder_hint, uint32_t reorder_hint_num_bits,
int invoke, uint32_t mask, uint32_t pipeline, uint32_t sbt);

// Read data from the SBT data buffer
extern JIT_EXPORT uint32_t jitc_optix_sbt_data_load(uint32_t sbt_data_ptr,
VarType type,
uint32_t offset,
uint32_t mask);

// Trigger a reordering of the GPU threads
extern JIT_EXPORT void jitc_optix_reorder(uint32_t key, uint32_t num_bits,
uint32_t n_values, uint32_t *values,
uint32_t *out);

/// Compile an OptiX kernel
extern bool jitc_optix_compile(ThreadState *ts, const char *buffer,
size_t buffer_size, const char *kernel_name,
Expand All @@ -71,4 +75,5 @@ extern void jitc_optix_launch(ThreadState *ts, const Kernel &kernel,
uint32_t size, const void *args, uint32_t n_args);

/// Optional: set the desired launch size
extern void jitc_optix_set_launch_size(uint32_t width, uint32_t height, uint32_t samples);
extern void jitc_optix_set_launch_size(uint32_t width, uint32_t height,
uint32_t samples);
100 changes: 98 additions & 2 deletions src/optix_core.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -500,8 +500,10 @@ void jitc_optix_launch(ThreadState *ts, const Kernel &kernel,
void jitc_optix_ray_trace(uint32_t n_args, uint32_t *args,
uint32_t n_hit_object_field,
OptixHitObjectField *hit_object_fields,
uint32_t *hit_object_out, int invoke, uint32_t mask,
uint32_t pipeline, uint32_t sbt) {
uint32_t *hit_object_out, int reorder,
uint32_t reorder_hint, uint32_t reorder_hint_num_bits,
int invoke, uint32_t mask, uint32_t pipeline,
uint32_t sbt) {
VarType types[]{ VarType::UInt64, VarType::Float32, VarType::Float32,
VarType::Float32, VarType::Float32, VarType::Float32,
VarType::Float32, VarType::Float32, VarType::Float32,
Expand Down Expand Up @@ -561,6 +563,16 @@ void jitc_optix_ray_trace(uint32_t n_args, uint32_t *args,
if (hit_object_fields[i] >= OptixHitObjectField::Count)
jitc_raise("jit_optix_ray_trace(): unknown hit object field!");


if (reorder) {
if (reorder_hint_num_bits > 16)
jitc_fail("jit_optix_ray_trace(): a maximum of 16 bits can be used for "
"the reordering key!");
if ((VarType) jitc_var(reorder_hint)->type != VarType::UInt32)
jitc_raise("jit_optix_ray_trace(): 'reorder_hint' must be an "
"unsigned 32-bit array.");
}

// Potentially apply any masks on the mask stack
Ref valid = steal(jitc_var_mask_apply(mask, size));

Expand All @@ -574,6 +586,9 @@ void jitc_optix_ray_trace(uint32_t n_args, uint32_t *args,
// Fill payload information for node
TraceData *td = new TraceData();
td->invoke = invoke;
td->reorder = reorder;
td->reorder_hint = reorder_hint;
td->reorder_hint_num_bits = reorder_hint_num_bits;
td->indices.reserve(n_args);
for (uint32_t i = 0; i < n_args; ++i) {
uint32_t id = args[i];
Expand All @@ -592,6 +607,11 @@ void jitc_optix_ray_trace(uint32_t n_args, uint32_t *args,
Variable *v = jitc_var(index);
v->optix = 1;

if (reorder && reorder_hint_num_bits > 0) {
v->dep[3] = reorder_hint;
jitc_var_inc_ref(reorder_hint);
}

// Extract payload values
for (uint32_t i = 0; i < np; ++i)
args[15 + i] = jitc_var_new_node_1(
Expand Down Expand Up @@ -655,6 +675,82 @@ uint32_t jitc_optix_sbt_data_load(uint32_t sbt_data_ptr, VarType type,
mask, jitc_var(mask), offset);
}

void jitc_optix_reorder(uint32_t key, uint32_t num_bits, uint32_t n_values,
uint32_t *values, uint32_t *out) {
Variable *v_key = jitc_var(key);

if ((JitBackend) v_key->backend != JitBackend::CUDA) {
for (uint32_t i = 0; i < n_values; ++i) {
jitc_var_inc_ref(values[i]);
out[i] = values[i];
}
return;
}

if ((VarType) v_key->type != VarType::UInt32)
jitc_raise("jit_optix_reorder(): 'key' must be an unsigned 32-bit array.");

uint32_t size = 0;
bool symbolic = v_key->symbolic;
bool dirty = v_key->is_dirty();
for (uint32_t i = 0; i <= n_values; ++i) {
uint32_t index = (i < n_values) ? values[i] : key;
Variable *v = jitc_var(index);
size = std::max(size, v->size);
symbolic |= (bool) v->symbolic;
dirty |= v->is_dirty();
}

for (uint32_t i = 0; i <= n_values; ++i) {
uint32_t index = (i < n_values) ? values[i] : key;
const Variable *v = jitc_var(index);
if (v->size != 1 && v->size != size)
jitc_raise("jit_optix_reorder(): incompatible array sizes!");
}

if (dirty) {
jitc_eval(thread_state(JitBackend::CUDA));

for (uint32_t i = 0; i < n_values; ++i) {
if (jitc_var(values[i])->is_dirty())
jitc_raise_dirty_error(values[i]);
}
}

if (num_bits < 1)
jitc_fail("jit_optix_reorder(): the key must be at least one bit!");
if (num_bits > 16)
jitc_fail("jit_optix_reorder(): a maximum of 16 bits can be used for "
"the key!");

// Guarantee that the reordering is assembled after anything that precedes this operation
jitc_new_scope(JitBackend::CUDA);

uint32_t reorder = jitc_var_new_node_1(
JitBackend::CUDA, VarKind::ReorderThread, VarType::Void, size, symbolic,
key, v_key, num_bits);
Variable *v_reorder = jitc_var(reorder);
v_reorder->optix = 1;

// Guarantee that the reordering is assembled before anything that follows
jitc_new_scope(JitBackend::CUDA);

for (uint32_t i = 0; i < n_values; ++i) {
// Values are just a `Bitcast` of their original value (no-op). We
// additionally add the reodering node as the second dependency. This
// guarantees that it will be detected during `jitc_var_traverse` and
// that it will only be assembled once, despite it not being directly
// used when assembling a `Bitcast` node.
Variable *v_value = jitc_var(values[i]);
out[i] = jitc_var_new_node_2(JitBackend::CUDA, VarKind::Bitcast,
(VarType) v_value->type, v_value->size,
v_value->symbolic, values[i], v_value,
reorder, v_reorder);
}

jitc_var_dec_ref(reorder, v_reorder);
}

void jitc_optix_check_impl(OptixResult errval, const char *file,
const int line) {
if (unlikely(errval != 0)) {
Expand Down
3 changes: 3 additions & 0 deletions src/trace.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ struct TraceData {
#if defined(DRJIT_ENABLE_OPTIX)
std::vector<uint32_t> hit_object_fields;
bool invoke;
bool reorder;
uint32_t reorder_hint;
uint32_t reorder_hint_num_bits;
#endif

~TraceData() {
Expand Down
Loading