Skip to content

Commit b25effa

Browse files
Addressed comments from Wenzel
1 parent a244172 commit b25effa

File tree

4 files changed

+19
-14
lines changed

4 files changed

+19
-14
lines changed

include/drjit-core/jit.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1176,8 +1176,7 @@ extern JIT_EXPORT uint32_t jit_var_data(uint32_t index, void **ptr_out);
11761176
/// Query the size of a given variable
11771177
extern JIT_EXPORT size_t jit_var_size(uint32_t index);
11781178

1179-
/// Query the size of a given variable, as an opaque variable. This also
1180-
/// notifies the ThreadState, allowing us to record this for frozen functions.
1179+
/// Query the size of a given variable, as an opaque variable.
11811180
extern JIT_EXPORT uint32_t jit_var_opaque_width(uint32_t index);
11821181

11831182
/// Query the type of a given variable

src/api.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -616,16 +616,17 @@ uint32_t jit_var_opaque_width(uint32_t index) {
616616

617617
lock_guard guard(state.lock);
618618

619-
Variable *var = jitc_var(index);
620-
uint32_t var_size = var->size;
619+
Variable *var = jitc_var(index);
620+
JitBackend backend = (JitBackend) var->backend;
621+
uint32_t var_size = var->size;
621622

622623
// The variable has to be evaluated, to notify the ThreadState
623624
jitc_var_eval(index);
624625

625-
uint32_t width_index = jitc_var_literal(
626-
(JitBackend) var->backend, VarType::UInt32, &var_size, 1, true);
626+
uint32_t width_index =
627+
jitc_var_literal(backend, VarType::UInt32, &var_size, 1, true);
627628

628-
ThreadState *ts = thread_state(var->backend);
629+
ThreadState *ts = thread_state(backend);
629630
ts->notify_opaque_width(index, width_index);
630631

631632
return width_index;

src/internal.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -720,6 +720,11 @@ struct ThreadState : public ThreadStateBase {
720720
virtual void reduce_expanded(VarType vt, ReduceOp op, void *data,
721721
uint32_t exp, uint32_t size) = 0;
722722

723+
/// Some kernels use the width of an array in a computation. When using the
724+
/// kernel freezing feature, this requires special precautions to ensure
725+
/// that the resulting capture remains usable with different array sizes.
726+
/// This notification function exists so that this special-case handling can
727+
/// be realized.
723728
virtual void notify_opaque_width(uint32_t index, uint32_t width_index);
724729

725730
/// Notify the \c ThreadState that \c jitc_free has been called on a pointer.

src/record_ts.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -476,11 +476,11 @@ void RecordThreadState::barrier() {
476476
}
477477

478478
void RecordThreadState::notify_opaque_width(uint32_t index,
479-
uint32_t width_index) {
479+
uint32_t width_index) {
480480
if (!paused()) {
481481
uint32_t start = m_recording.dependencies.size();
482-
Variable *v1 = jitc_var(index);
483-
Variable *v2 = jitc_var(width_index);
482+
Variable *v1 = jitc_var(index);
483+
Variable *v2 = jitc_var(width_index);
484484
add_in_param(v1->data, (VarType) v1->type);
485485
add_out_param(v2->data, VarType::UInt32);
486486
uint32_t end = m_recording.dependencies.size();
@@ -492,13 +492,13 @@ void RecordThreadState::notify_opaque_width(uint32_t index,
492492
}
493493
}
494494

495-
int Recording::replay_opaque_width(Operation &op){
495+
int Recording::replay_opaque_width(Operation &op) {
496496

497497
uint32_t dependency_index = op.dependency_range.first;
498-
AccessInfo in_info = dependencies[dependency_index];
499-
AccessInfo out_info = dependencies[dependency_index + 1];
498+
AccessInfo in_info = dependencies[dependency_index];
499+
AccessInfo out_info = dependencies[dependency_index + 1];
500500

501-
ReplayVariable &in_var = replay_variables[in_info.slot];
501+
ReplayVariable &in_var = replay_variables[in_info.slot];
502502
ReplayVariable &out_var = replay_variables[out_info.slot];
503503

504504
out_var.alloc(backend, 1, out_info.vtype);

0 commit comments

Comments
 (0)