Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions dlib/cuda/cpu_dlib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3267,6 +3267,7 @@ namespace dlib
resizable_tensor& cumulative_halting,
resizable_tensor& remainders,
resizable_tensor& n_steps,
resizable_tensor& effective_weights,
long batch_size,
long seq_len,
long d_model,
Expand All @@ -3281,6 +3282,7 @@ namespace dlib
float* cum_halt = cumulative_halting.host();
float* remain = remainders.host();
float* steps = n_steps.host();
float* eff_weights = effective_weights.host();

for (long pos = 0; pos < batch_size * seq_len; ++pos) {
if (cum_halt[pos] < halt_threshold) {
Expand All @@ -3294,6 +3296,7 @@ namespace dlib
cum_halt[pos] += effective;
remain[pos] -= effective;
steps[pos] = static_cast<float>(current_step + 1);
eff_weights[pos] += effective;

for (long c = 0; c < num_channels; ++c) {
for (long d = 0; d < d_model; ++d) {
Expand All @@ -3309,6 +3312,7 @@ namespace dlib
resizable_tensor& output,
const tensor& input_data,
const tensor& remainders,
resizable_tensor& effective_weights,
long batch_size,
long seq_len,
long d_model,
Expand All @@ -3318,13 +3322,16 @@ namespace dlib
const float* in_ptr = input_data.host();
const float* remain = remainders.host();
float* out_ptr = output.host();
float* eff_weights = effective_weights.host();

for (long pos = 0; pos < batch_size * seq_len; ++pos) {
float r = remain[pos];
if (r > 1e-6f) {
const long n = pos / seq_len;
const long s = pos % seq_len;

eff_weights[pos] += r;

for (long c = 0; c < num_channels; ++c) {
for (long d = 0; d < d_model; ++d) {
const long idx = ((n * num_channels + c) * seq_len + s) * d_model + d;
Expand Down
2 changes: 2 additions & 0 deletions dlib/cuda/cpu_dlib.h
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,7 @@ namespace dlib
resizable_tensor& cumulative_halting,
resizable_tensor& remainders,
resizable_tensor& n_steps,
resizable_tensor& effective_weights,
long batch_size,
long seq_len,
long d_model,
Expand All @@ -567,6 +568,7 @@ namespace dlib
resizable_tensor& output,
const tensor& input_data,
const tensor& remainders,
resizable_tensor& effective_weights,
long batch_size,
long seq_len,
long d_model,
Expand Down
9 changes: 9 additions & 0 deletions dlib/cuda/cuda_dlib.cu
Original file line number Diff line number Diff line change
Expand Up @@ -2819,6 +2819,7 @@ namespace dlib
float* cumulative_halting,
float* remainders,
float* n_steps,
float* effective_weights,
size_t batch_size,
size_t seq_len,
size_t d_model,
Expand All @@ -2841,6 +2842,7 @@ namespace dlib
cumulative_halting[pos] += effective;
remainders[pos] -= effective;
n_steps[pos] = static_cast<float>(current_step + 1);
effective_weights[pos] += effective;

for (size_t c = 0; c < num_channels; ++c) {
for (size_t d = 0; d < d_model; ++d) {
Expand All @@ -2859,6 +2861,7 @@ namespace dlib
resizable_tensor& cumulative_halting,
resizable_tensor& remainders,
resizable_tensor& n_steps,
resizable_tensor& effective_weights,
long batch_size,
long seq_len,
long d_model,
Expand All @@ -2877,6 +2880,7 @@ namespace dlib
cumulative_halting.device(),
remainders.device(),
n_steps.device(),
effective_weights.device(),
batch_size,
seq_len,
d_model,
Expand All @@ -2889,6 +2893,7 @@ namespace dlib
float* output,
const float* input_data,
const float* remainders,
float* effective_weights,
size_t batch_size,
size_t seq_len,
size_t d_model,
Expand All @@ -2902,6 +2907,8 @@ namespace dlib
const size_t n = pos / seq_len;
const size_t s = pos % seq_len;

effective_weights[pos] += r;

for (size_t c = 0; c < num_channels; ++c) {
for (size_t d = 0; d < d_model; ++d) {
const size_t idx = ((n * num_channels + c) * seq_len + s) * d_model + d;
Expand All @@ -2916,6 +2923,7 @@ namespace dlib
resizable_tensor& output,
const tensor& input_data,
const tensor& remainders,
resizable_tensor& effective_weights,
long batch_size,
long seq_len,
long d_model,
Expand All @@ -2929,6 +2937,7 @@ namespace dlib
output.device(),
input_data.device(),
remainders.device(),
effective_weights.device(),
batch_size,
seq_len,
d_model,
Expand Down
2 changes: 2 additions & 0 deletions dlib/cuda/cuda_dlib.h
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,7 @@ namespace dlib
resizable_tensor& cumulative_halting,
resizable_tensor& remainders,
resizable_tensor& n_steps,
resizable_tensor& effective_weights,
long batch_size,
long seq_len,
long d_model,
Expand All @@ -639,6 +640,7 @@ namespace dlib
resizable_tensor& output,
const tensor& input_data,
const tensor& remainders,
resizable_tensor& effective_weights,
long batch_size,
long seq_len,
long d_model,
Expand Down
10 changes: 6 additions & 4 deletions dlib/cuda/tensor_tools.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1440,6 +1440,7 @@ namespace dlib { namespace tt
resizable_tensor& cumulative_halting,
resizable_tensor& remainders,
resizable_tensor& n_steps,
resizable_tensor& effective_weights,
long batch_size,
long seq_len,
long d_model,
Expand All @@ -1450,28 +1451,29 @@ namespace dlib { namespace tt
{
#ifdef DLIB_USE_CUDA
cuda::update_act_state(output, input_data, halt_probs, cumulative_halting, remainders,
n_steps, batch_size, seq_len, d_model, num_channels, halt_threshold, current_step);
n_steps, effective_weights, batch_size, seq_len, d_model, num_channels, halt_threshold, current_step);
#else
cpu::update_act_state(output, input_data, halt_probs, cumulative_halting, remainders,
n_steps, batch_size, seq_len, d_model, num_channels, halt_threshold, current_step);
n_steps, effective_weights, batch_size, seq_len, d_model, num_channels, halt_threshold, current_step);
#endif
}

void finalize_act_output(
resizable_tensor& output,
const tensor& input_data,
const tensor& remainders,
resizable_tensor& effective_weights,
long batch_size,
long seq_len,
long d_model,
long num_channels
)
{
#ifdef DLIB_USE_CUDA
cuda::finalize_act_output(output, input_data, remainders,
cuda::finalize_act_output(output, input_data, remainders, effective_weights,
batch_size, seq_len, d_model, num_channels);
#else
cpu::finalize_act_output(output, input_data, remainders,
cpu::finalize_act_output(output, input_data, remainders, effective_weights,
batch_size, seq_len, d_model, num_channels);
#endif
}
Expand Down
9 changes: 6 additions & 3 deletions dlib/cuda/tensor_tools.h
Original file line number Diff line number Diff line change
Expand Up @@ -2428,6 +2428,7 @@ namespace dlib { namespace tt
resizable_tensor& cumulative_halting,
resizable_tensor& remainders,
resizable_tensor& n_steps,
resizable_tensor& effective_weights,
long batch_size,
long seq_len,
long d_model,
Expand All @@ -2445,12 +2446,12 @@ namespace dlib { namespace tt
- input_data.nc() == d_model
- output has the same dimensions as input_data
- halt_probs.size() == batch_size * seq_len
- cumulative_halting.size() == remainders.size() == n_steps.size() == batch_size * seq_len
- cumulative_halting.size() == remainders.size() == n_steps.size() == effective_weights.size() == batch_size * seq_len
ensures
- Core ACT update step that accumulates weighted outputs:
- Updates ACT state for all positions
- Accumulates weighted outputs: output += α_t^n * input_data
- Updates cumulative_halting, remainders, and n_steps
- Updates cumulative_halting, remainders, and n_steps, and accumulates the effective weight α_t^n into effective_weights
- batch_size: number of samples in the batch
- seq_len: sequence length (number of positions to process)
- d_model: model dimension per channel
Expand All @@ -2463,6 +2464,7 @@ namespace dlib { namespace tt
resizable_tensor& output,
const tensor& input_data,
const tensor& remainders,
resizable_tensor& effective_weights,
long batch_size,
long seq_len,
long d_model,
Expand All @@ -2475,10 +2477,11 @@ namespace dlib { namespace tt
- input_data.nr() == seq_len
- input_data.nc() == d_model
- output has the same dimensions as input_data
- remainders.size() == batch_size * seq_len
- remainders.size() == effective_weights.size() == batch_size * seq_len
ensures
- Finalizes ACT output by adding remainder contributions:
- Adds final remainder contributions: output += ρ_t * input_data
- Accumulates the same remainder into effective_weights: effective_weights += ρ_t
- Applied only to positions with significant remainder (> 1e-6)
- batch_size: number of samples in the batch
- seq_len: sequence length (number of positions to process)
Expand Down
30 changes: 2 additions & 28 deletions dlib/dnn/layers.h
Original file line number Diff line number Diff line change
Expand Up @@ -5857,29 +5857,11 @@ namespace dlib
halting_probs_, logits_, input, params,
batch_size_, seq_len_, feature_dim_);

// Capture effective weights before state update
const float* p_halt = halting_probs_.host();
const float* cum_halt = cum_halt_ptr;
const float* remainders = remainders_ptr;
float* true_weights = true_effective_weights_.host();

for (long pos = 0; pos < total_positions; ++pos) {
if (cum_halt[pos] < halt_threshold_) {
float p = p_halt[pos];
float r = remainders[pos];

// Compute effective weight: alpha_t^n = min(p * rho, theta - h_t^(n-1))
float effective = std::min(p * r, halt_threshold_ - cum_halt[pos]);

// Store for backward pass
true_weights[pos] += effective;
}
}

// Update ACT state and accumulate weighted outputs
tt::update_act_state(
output, input, halting_probs_,
cumulative_halting_, remainders_, n_steps_,
true_effective_weights_,
batch_size_, seq_len_, d_model_, num_channels_,
halt_threshold_, step
);
Expand All @@ -5891,17 +5873,9 @@ namespace dlib
// Finalize with remainder contributions
tt::finalize_act_output(
output, input, remainders_,
true_effective_weights_,
batch_size_, seq_len_, d_model_, num_channels_);

// Add remainder weights for gradient computation
const float* final_remainders = remainders_.host();
float* true_weights = true_effective_weights_.host();
for (long pos = 0; pos < total_positions; ++pos) {
if (final_remainders[pos] > 1e-6f) {
true_weights[pos] += final_remainders[pos];
}
}

// Compute statistics for monitoring and regularization
compute_ponder_stats();
}
Expand Down
Loading