Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions src/liger_kernel/ops/backends/_ascend/ops/cross_entropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,64 @@ def liger_cross_entropy_forward_kernel(
predicted_tokens_row_ptr = predicted_tokens_row_ptr + predicted_tokens_stride


@triton.jit
def liger_cross_entropy_forward_kernel_plain(
    X_ptr,
    X_stride,
    Y_ptr,
    loss_ptr,
    n_cols,
    n_rows,
    ce_stats_ptr,
    ignore_index,
    reduction: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """
    Cross-entropy forward fast path for the plain case (no class weights,
    softcapping, label smoothing, z-loss or metrics).

    Each program instance handles exactly one row, trading the row chunking
    of the generic kernel for maximum parallelism. The per-row loss is
    ``logsumexp(x) - x[y]``, with the log-sum-exp accumulated block by block
    using a running maximum so only ``BLOCK_SIZE`` logits are live at a time.

    Args:
        X_ptr: Base pointer of the logits.
        X_stride: Offset (in elements) between consecutive logit rows.
        Y_ptr: Pointer to the target index of each row.
        loss_ptr: Output pointer; one loss value is stored per row
            (float32 recommended).
        n_cols: Number of logits per row.
        n_rows: Number of rows; program ids at or beyond it exit immediately.
        ce_stats_ptr: Pointer whose element 0 holds the scale applied when
            ``reduction == "mean"``.
            NOTE(review): presumably 1 / (number of non-ignored targets) --
            confirm against the host-side launcher.
        ignore_index: Target value marking a row to skip; such rows get a
            loss of 0.0.
        reduction: Compile-time reduction mode. Only ``"mean"`` changes the
            stored value (it is multiplied by the scale above).
        BLOCK_SIZE: Compile-time number of columns processed per iteration.
    """
    pid = tl.program_id(0)
    if pid >= n_rows:
        return

    # 64-bit row index so the row offset into X cannot overflow 32 bits.
    row_idx = pid.to(tl.int64)
    target = tl.load(Y_ptr + row_idx)

    # The mean scale is fetched up front, before the ignore-index check; it
    # is only consumed when reduction == "mean".
    mean_scale = tl.load(ce_stats_ptr + 0).cast(tl.float32)

    if target == ignore_index:
        tl.store(loss_ptr + row_idx, 0.0)
        return

    row_base = X_ptr + row_idx * X_stride

    # x[y] is read with a single direct load instead of being selected inside
    # the block loop: tracking it there needs extra selection state and tends
    # to raise UB/register pressure on Ascend.
    # NOTE(review): no bounds check on the target -- assumes
    # 0 <= target < n_cols for every non-ignored row.
    target_logit = tl.load(row_base + target).cast(tl.float32)

    running_max = float("-inf")
    running_sum = 0.0
    for col_start in range(0, n_cols, BLOCK_SIZE):
        cols = col_start + tl.arange(0, BLOCK_SIZE)
        in_bounds = cols < n_cols
        # Out-of-range lanes read as -inf, so they contribute exp(-inf) = 0.
        tile = tl.load(
            row_base + cols,
            mask=in_bounds,
            other=float("-inf"),
            eviction_policy="evict_first",
        ).cast(tl.float32)
        new_max = tl.maximum(running_max, tl.max(tile))
        # Re-express the sum accumulated so far relative to the new maximum
        # before adding this block's contribution.
        rescaled_sum = running_sum * tl.exp(running_max - new_max)
        running_sum = rescaled_sum + tl.sum(tl.exp(tile - new_max))
        running_max = new_max

    log_sum_exp = running_max + tl.log(running_sum)
    row_loss = log_sum_exp - target_logit
    if reduction == "mean":
        row_loss = row_loss * mean_scale
    tl.store(loss_ptr + row_idx, row_loss)


@triton.jit
def liger_cross_entropy_backward_kernel_no_weight(
X_ptr,
Expand Down
Loading
Loading