Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 74 additions & 36 deletions src/liger_kernel/ops/backends/_ascend/ops/grpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,11 @@ def calculate_tile_count_2d(batch_size, seq_len, num_cores):

def compute_block_size_softmax(seq_vocab_size):
"""Determine optimal block size for selective log-softmax kernel."""
multiplier = 6.0
# Only one BLOCK_N buffer needed in UB (logits load); scalars stay in registers.
# Conservative: 3.0 accounts for compiler register allocation margin.
# UB capacity (192KB, 0.9 safety) / (3.0 * 4B) = ~14.7K -> power-of-2 = 8192.
# Capped by min(N, 8192) in tiling strategy, so typical vocab (32K-128K) gets 8192.
multiplier = 3.0
tile_shapes = compute_default_tiling_strategy(
safety_margin=0.9, dtype_size=4, memory_multiplier=multiplier, shapes=((seq_vocab_size,),), tiling_dims=(0,)
)
Expand All @@ -48,7 +52,11 @@ def compute_block_size_softmax(seq_vocab_size):

def compute_block_size_forward(seq_vocab_size):
"""Determine optimal block size for forward pass kernel."""
multiplier = 10.0
# Only one BLOCK_N buffer needed in UB (logits load); scalars stay in registers.
# 4.0 provides safety margin while enabling larger BLOCK_N vs legacy 10.0.
# UB capacity (192KB, 0.9 safety) / (4.0 * 4B) = ~11K -> power-of-2 = 8192.
# Capped by min(N, 8192), so typical vocab gets up to 8192 (vs legacy 2048).
multiplier = 4.0
tile_shapes = compute_default_tiling_strategy(
safety_margin=0.9, dtype_size=4, memory_multiplier=multiplier, shapes=((seq_vocab_size,),), tiling_dims=(0,)
)
Expand All @@ -59,7 +67,10 @@ def compute_block_size_forward(seq_vocab_size):

def compute_block_size_backward(seq_vocab_size):
"""Determine optimal block size for backward pass kernel."""
multiplier = 12.0
# Backward has higher UB pressure: needs logits + probs + dlogits_chunk in UB.
# 8.0 is a balanced reduction from legacy 12.0.
# UB capacity (192KB, 0.9 safety) / (8.0 * 4B) = ~5.5K -> power-of-2 = 4096.
multiplier = 8.0
tile_shapes = compute_default_tiling_strategy(
safety_margin=0.9, dtype_size=4, memory_multiplier=multiplier, shapes=((seq_vocab_size,),), tiling_dims=(0,)
)
Expand All @@ -78,7 +89,7 @@ def _selective_log_softmax_kernel(
stride_input_ids_b,
L: tl.constexpr,
N: tl.constexpr,
BLOCK_N: tl.constexpr = 2048,
BLOCK_N: tl.constexpr = 4096,
):
pid_b = tl.program_id(0)
pid_l = tl.program_id(1)
Expand All @@ -89,6 +100,9 @@ def _selective_log_softmax_kernel(
start_token = batch_start + pid_l
stride = num_progs_l

# Precompute 1/TEMPERATURE to replace repeated division with multiplication
inv_temp = 1.0 / TEMPERATURE

for token_idx in tl.range(start_token, batch_end, stride):
off_b = token_idx // L
off_l = token_idx % L
Expand All @@ -106,17 +120,19 @@ def _selective_log_softmax_kernel(

m_i = float("-inf")
l_i = 0.0
for start in range(0, N, BLOCK_N):
# Use tl.static_range for inner softmax loop (BLOCK_N is constexpr)
for start in tl.static_range(0, N, BLOCK_N):
cols = start + tl.arange(0, BLOCK_N)
logits = tl.load(LOGITS_local + cols, mask=cols < N, other=float("-inf")).to(tl.float32) / TEMPERATURE
cols_mask = cols < N
logits = tl.load(LOGITS_local + cols, mask=cols_mask, other=float("-inf")).to(tl.float32) * inv_temp
new_m_i = tl.maximum(m_i, tl.max(logits))
alpha = tl.exp(m_i - new_m_i)
l_i = l_i * alpha + tl.sum(tl.exp(logits - new_m_i))
m_i = new_m_i
lse = m_i + tl.log(l_i)

ids = tl.load(INPUT_IDS_local)
x = tl.load(LOGITS_local + ids).to(tl.float32) / TEMPERATURE
x = tl.load(LOGITS_local + ids).to(tl.float32) * inv_temp
logp = x - lse
tl.store(LOG_P_local, logp)

Expand Down Expand Up @@ -146,7 +162,7 @@ def _grpo_loss_fwd_kernel(
USE_BIAS_CORRECTION_KL: tl.constexpr,
L: tl.constexpr,
N: tl.constexpr,
BLOCK_N: tl.constexpr = 2048,
BLOCK_N: tl.constexpr = 4096,
):
pid_b = tl.program_id(0)
pid_l = tl.program_id(1)
Expand All @@ -157,6 +173,9 @@ def _grpo_loss_fwd_kernel(
start_token = batch_start + pid_l
stride = num_progs_l

# Precompute 1/TEMPERATURE to replace repeated division with multiplication
inv_temp = 1.0 / TEMPERATURE

for token_idx in tl.range(start_token, batch_end, stride):
off_b = token_idx // L
off_l = token_idx % L
Expand All @@ -177,17 +196,20 @@ def _grpo_loss_fwd_kernel(

m_i = float("-inf")
l_i = 0.0
for start in range(0, N, BLOCK_N):
# Use tl.static_range for inner softmax loop (BLOCK_N is constexpr)
# to give the compiler unrolling hints for better instruction scheduling
for start in tl.static_range(0, N, BLOCK_N):
cols = start + tl.arange(0, BLOCK_N)
logits = tl.load(LOGITS_local + cols, mask=cols < N, other=float("-inf")).to(tl.float32) / TEMPERATURE
cols_mask = cols < N
logits = tl.load(LOGITS_local + cols, mask=cols_mask, other=float("-inf")).to(tl.float32) * inv_temp
new_m_i = tl.maximum(m_i, tl.max(logits))
alpha = tl.exp(m_i - new_m_i)
l_i = l_i * alpha + tl.sum(tl.exp(logits - new_m_i))
m_i = new_m_i
lse = m_i + tl.log(l_i)

idx = tl.load(INPUT_IDS_local)
x = tl.load(LOGITS_local + idx).to(tl.float32) / TEMPERATURE
idx = tl.load(INPUT_IDS_local).to(tl.int32)
x = tl.load(LOGITS_local + idx).to(tl.float32) * inv_temp
logp = x - lse
if OLD_LOGP is None:
old_logp = logp
Expand Down Expand Up @@ -263,7 +285,7 @@ def _grpo_loss_fwd_kernel_seq(
USE_BIAS_CORRECTION_KL: tl.constexpr,
L: tl.constexpr,
N: tl.constexpr,
BLOCK_N: tl.constexpr = 2048,
BLOCK_N: tl.constexpr = 4096,
):
pid_b = tl.program_id(0)
pid_l = tl.program_id(1)
Expand All @@ -274,6 +296,9 @@ def _grpo_loss_fwd_kernel_seq(
start_token = batch_start + pid_l
stride = num_progs_l

# Precompute 1/TEMPERATURE to replace repeated division with multiplication
inv_temp = 1.0 / TEMPERATURE

for token_idx in tl.range(start_token, batch_end, stride):
off_b = token_idx // L
off_l = token_idx % L
Expand All @@ -297,17 +322,19 @@ def _grpo_loss_fwd_kernel_seq(

m_i = float("-inf")
l_i = 0.0
for start in range(0, N, BLOCK_N):
# Use tl.static_range for inner softmax loop (BLOCK_N is constexpr)
for start in tl.static_range(0, N, BLOCK_N):
cols = start + tl.arange(0, BLOCK_N)
logits = tl.load(LOGITS_local + cols, mask=cols < N, other=float("-inf")).to(tl.float32) / TEMPERATURE
cols_mask = cols < N
logits = tl.load(LOGITS_local + cols, mask=cols_mask, other=float("-inf")).to(tl.float32) * inv_temp
new_m_i = tl.maximum(m_i, tl.max(logits))
alpha = tl.exp(m_i - new_m_i)
l_i = l_i * alpha + tl.sum(tl.exp(logits - new_m_i))
m_i = new_m_i
lse = m_i + tl.log(l_i)

idx = tl.load(INPUT_IDS_local)
x = tl.load(LOGITS_local + idx).to(tl.float32) / TEMPERATURE
idx = tl.load(INPUT_IDS_local).to(tl.int32)
x = tl.load(LOGITS_local + idx).to(tl.float32) * inv_temp
logp = x - lse

coef_1 = tl.load(COEF_1_local).to(tl.float32)
Expand Down Expand Up @@ -368,7 +395,7 @@ def _grpo_loss_bwd_kernel_seq(
loss_stride1,
L: tl.constexpr,
N: tl.constexpr,
BLOCK_N: tl.constexpr = 2048,
BLOCK_N: tl.constexpr = 4096,
):
pid_b = tl.program_id(0)
pid_l = tl.program_id(1)
Expand All @@ -379,6 +406,9 @@ def _grpo_loss_bwd_kernel_seq(
start_token = batch_start + pid_l
stride = num_progs_l

# Precompute 1/TEMPERATURE to replace repeated division with multiplication
inv_temp = 1.0 / TEMPERATURE

for token_idx in tl.range(start_token, batch_end, stride):
off_b = token_idx // L
off_l = token_idx % L
Expand All @@ -392,7 +422,7 @@ def _grpo_loss_bwd_kernel_seq(
should_process = not_skip

if should_process == 0:
for start in range(0, N, BLOCK_N):
for start in tl.static_range(0, N, BLOCK_N):
cols = tl.arange(0, BLOCK_N) + start
tl.store(DLOGITS_local + cols, 0.0, mask=cols < N)
else:
Expand All @@ -411,8 +441,8 @@ def _grpo_loss_bwd_kernel_seq(
coef_1 = tl.load(COEF_1_local).to(tl.float32)
seq_len = tl.load(SEQ_LEN_local).to(tl.float32)

idx = tl.load(INPUT_IDS_local)
x = tl.load(LOGITS_local + idx).to(tl.float32) / TEMPERATURE
idx = tl.load(INPUT_IDS_local).to(tl.int32)
x = tl.load(LOGITS_local + idx).to(tl.float32) * inv_temp
logp = x - lse

advantage = tl.load(ADVANTAGES_local).to(tl.float32)
Expand Down Expand Up @@ -442,13 +472,16 @@ def _grpo_loss_bwd_kernel_seq(
else:
dlogp += BETA * (1 - tl.exp(ref_logp - logp)) * dloss

dlogp = dlogp / TEMPERATURE
for start_n in tl.range(0, N, BLOCK_N):
dlogp = dlogp * inv_temp
# Use tl.static_range for inner loop (BLOCK_N is constexpr)
for start_n in tl.static_range(0, N, BLOCK_N):
cols = start_n + tl.arange(0, BLOCK_N)
logits = tl.load(LOGITS_local + cols, mask=cols < N, other=-float("inf")).to(tl.float32) / TEMPERATURE
cols_mask = cols < N
logits = tl.load(LOGITS_local + cols, mask=cols_mask, other=-float("inf")).to(tl.float32) * inv_temp
probs = tl.exp(logits - lse)
dlogits = tl.where(cols == idx, 1 - probs, -probs) * dlogp
tl.store(DLOGITS_local + cols, dlogits, mask=cols < N)
cols_idx = cols == idx
dlogits = (cols_idx - probs) * dlogp
tl.store(DLOGITS_local + cols, dlogits, mask=cols_mask)


@triton.jit
Expand Down Expand Up @@ -477,7 +510,7 @@ def _grpo_loss_bwd_kernel(
loss_stride1,
L: tl.constexpr,
N: tl.constexpr,
BLOCK_N: tl.constexpr = 2048,
BLOCK_N: tl.constexpr = 4096,
):
pid_b = tl.program_id(0)
pid_l = tl.program_id(1)
Expand All @@ -488,6 +521,9 @@ def _grpo_loss_bwd_kernel(
start_token = batch_start + pid_l
stride = num_progs_l

# Precompute 1/TEMPERATURE to replace repeated division with multiplication
inv_temp = 1.0 / TEMPERATURE

for token_idx in tl.range(start_token, batch_end, stride):
off_b = token_idx // L
off_l = token_idx % L
Expand All @@ -501,7 +537,7 @@ def _grpo_loss_bwd_kernel(
should_process = not_skip

if should_process == 0:
for start in range(0, N, BLOCK_N):
for start in tl.static_range(0, N, BLOCK_N):
cols = tl.arange(0, BLOCK_N) + start
tl.store(DLOGITS_local + cols, 0.0, mask=cols < N)
else:
Expand All @@ -514,8 +550,8 @@ def _grpo_loss_bwd_kernel(
dloss = tl.load(DLOSS_local).to(tl.float32)
lse = tl.load(LSE_local).to(tl.float32)

idx = tl.load(INPUT_IDS_local)
x = tl.load(LOGITS_local + idx).to(tl.float32) / TEMPERATURE
idx = tl.load(INPUT_IDS_local).to(tl.int32)
x = tl.load(LOGITS_local + idx).to(tl.float32) * inv_temp
logp = x - lse
if OLD_LOGP is None:
old_logp = logp
Expand Down Expand Up @@ -563,14 +599,16 @@ def _grpo_loss_bwd_kernel(
else:
dlogp += BETA * (1 - tl.exp(ref_logp - logp))

dlogp = dlogp * dloss / TEMPERATURE
tl.debug_barrier()
for start_n in tl.range(0, N, BLOCK_N):
dlogp = dlogp * dloss * inv_temp
# Use tl.static_range for inner loop (BLOCK_N is constexpr)
for start_n in tl.static_range(0, N, BLOCK_N):
cols = start_n + tl.arange(0, BLOCK_N)
logits = tl.load(LOGITS_local + cols, mask=cols < N, other=-float("inf")).to(tl.float32) / TEMPERATURE
cols_mask = cols < N
logits = tl.load(LOGITS_local + cols, mask=cols_mask, other=-float("inf")).to(tl.float32) * inv_temp
probs = tl.exp(logits - lse)
dlogits = tl.where(cols == idx, 1 - probs, -probs) * dlogp
tl.store(DLOGITS_local + cols, dlogits, mask=cols < N)
cols_idx = cols == idx
dlogits = (cols_idx - probs) * dlogp
tl.store(DLOGITS_local + cols, dlogits, mask=cols_mask)


@torch.no_grad
Expand Down
Loading