Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/liger_kernel/ops/backends/_ascend/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
If __all__ is not defined, all public symbols will be auto-discovered.
"""

from liger_kernel.ops.backends._ascend.ops.embedding import LigerEmbeddingFunction
from liger_kernel.ops.backends._ascend.ops.embedding import embedding_backward
from liger_kernel.ops.backends._ascend.ops.embedding import embedding_forward
from liger_kernel.ops.backends._ascend.ops.geglu import LigerGELUMulFunction
from liger_kernel.ops.backends._ascend.ops.geglu import geglu_backward
from liger_kernel.ops.backends._ascend.ops.geglu import geglu_forward
Expand All @@ -31,6 +34,9 @@
from liger_kernel.ops.backends._ascend.ops.tvd import tvd_backward_triton

__all__ = [
"LigerEmbeddingFunction",
"embedding_forward",
"embedding_backward",
"LigerGELUMulFunction",
"geglu_forward",
"geglu_backward",
Expand Down
214 changes: 214 additions & 0 deletions src/liger_kernel/ops/backends/_ascend/ops/embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
import torch
import triton
import triton.language as tl

from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy
from liger_kernel.ops.utils import ensure_contiguous
from liger_kernel.ops.utils import get_npu_core_count


@triton.jit
def embedding_forward_kernel(
embeddings_ptr,
indices_ptr,
output_ptr,
n_elements,
embedding_dim: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
NUM_STAGES: tl.constexpr,
):
Comment on lines +10 to +20

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the original implementation with 2 block sizes for tile shape is more readable and more efficient.

persistant grid loop is fine, but the way this kernel loading embedding seems to be uncoalesced at some point.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For instance, there will be some dim_idx not consecutive if BLOCK_SIZE is not multiple of embedding_dim. It will make the second tl.load trying to access different rows within a warp, as well as the last store.

Make these offsets created with 2d block size is more readable and efficient since we can avoid the uncoalesced access mentioned above.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have changed it to 2D block. After testing, it has indeed shown much better performance. The issues mentioned below have also been fixed. Could you please review it for me again?

pid = tl.program_id(0)
num_progs = tl.num_programs(0)

grid_m = tl.cdiv(n_elements, BLOCK_SIZE_M)
grid_n = tl.cdiv(embedding_dim, BLOCK_SIZE_N)
total_2d_blocks = grid_m * grid_n

for block_idx in tl.range(pid, total_2d_blocks, num_progs, num_stages=NUM_STAGES):
block_m = block_idx // grid_n
block_n = block_idx % grid_n

start_m = block_m * BLOCK_SIZE_M
start_n = block_n * BLOCK_SIZE_N

offsets_m = start_m + tl.arange(0, BLOCK_SIZE_M)
mask_m = offsets_m < n_elements

indices = tl.load(indices_ptr + offsets_m, mask=mask_m, other=0)

offsets_n = start_n + tl.arange(0, BLOCK_SIZE_N)
mask_n = offsets_n < embedding_dim

block_mask = mask_m[:, None] & mask_n[None, :]

embedding_offsets = indices[:, None] * embedding_dim + offsets_n[None, :]
embeddings = tl.load(
embeddings_ptr + embedding_offsets,
mask=block_mask,
other=0.0,
)

output_offsets = offsets_m[:, None] * embedding_dim + offsets_n[None, :]
tl.store(
output_ptr + output_offsets,
embeddings,
mask=block_mask,
)


@triton.jit
def embedding_backward_kernel(
grad_output_ptr,
grad_weight_ptr,
indices_ptr,
n_elements,
embedding_dim: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
NUM_STAGES: tl.constexpr,
):
pid = tl.program_id(0)
num_progs = tl.num_programs(0)

grid_m = tl.cdiv(n_elements, BLOCK_SIZE_M)
grid_n = tl.cdiv(embedding_dim, BLOCK_SIZE_N)
total_2d_blocks = grid_m * grid_n

for block_idx in tl.range(pid, total_2d_blocks, num_progs, num_stages=NUM_STAGES):
block_m = block_idx // grid_n
block_n = block_idx % grid_n

start_m = block_m * BLOCK_SIZE_M
start_n = block_n * BLOCK_SIZE_N

offsets_m = start_m + tl.arange(0, BLOCK_SIZE_M)
mask_m = offsets_m < n_elements

indices = tl.load(indices_ptr + offsets_m, mask=mask_m, other=0)

offsets_n = start_n + tl.arange(0, BLOCK_SIZE_N)
mask_n = offsets_n < embedding_dim

block_mask = mask_m[:, None] & mask_n[None, :]

grad_output_offsets = offsets_m[:, None] * embedding_dim + offsets_n[None, :]
grad_output = tl.load(
grad_output_ptr + grad_output_offsets,
mask=block_mask,
other=0.0,
)

grad_weight_offsets = indices[:, None] * embedding_dim + offsets_n[None, :]
tl.atomic_add(
grad_weight_ptr + grad_weight_offsets,
grad_output,
mask=block_mask,
)


def get_optimal_block_size(total_elements, dtype_size, BLOCK_SIZE_N: tl.constexpr):
# 1. Set Memory Multiplier
# 3.0 are empirical values based on 910B UB (192KB)
# embedding_offsets, embedding_offsets : BLOCK_SIZE_N * BLOCK_SIZE_M (total 2 * BLOCK_SIZE_N * BLOCK_SIZE_M)
# Reserve a unit of space for the remaining one-dimensional ub to occupy.
# A conservative estimate of the total space occupation is 3 * BLOCK_SIZE_N * BLOCK_SIZE_M
multiplier = 3.0

# 2. Call calculation function
# Treat input as 1D (total_elements,), only tiling on dim 0
tile_shapes = compute_default_tiling_strategy(
safety_margin=0.9,
dtype_size=dtype_size,
memory_multiplier=multiplier,
shapes=((total_elements, BLOCK_SIZE_N),),
tiling_dims=(0,),
)
Comment on lines +120 to +126

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dtype_size should be embedding.dtype?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Modified


# 3. Parse result
if tile_shapes and len(tile_shapes) > 0:
block_size = tile_shapes[0][0]
return block_size
else:
return triton.next_power_of_2(min(128, total_elements))


def embedding_forward(embeddings, indices):
ori_shape = indices.shape
indices = indices.view(-1)

n_elements = indices.numel()
embedding_dim = embeddings.shape[1]
output = torch.empty(
indices.shape[0],
embeddings.shape[1],
device=indices.device,
dtype=embeddings.dtype,
)

# Due to the involvement of two-dimensional partitioning,
# the sizes of block_m and block_n in the ub space will influence each other.
# Considering that embedding_dim is usually relatively smaller in most cases,
# a value is first assigned to block_n, and then the largest possible block_m is used.
BLOCK_SIZE_N = triton.next_power_of_2(min(128, embedding_dim))
BLOCK_SIZE_M = get_optimal_block_size(n_elements, embeddings.element_size(), BLOCK_SIZE_N)
num_cores = get_npu_core_count()
total_blocks = triton.cdiv(n_elements, BLOCK_SIZE_M) * triton.cdiv(embedding_dim, BLOCK_SIZE_N)
grid = min(num_cores, total_blocks)

embedding_forward_kernel[(grid,)](
embeddings,
indices,
output,
n_elements,
embedding_dim=embedding_dim,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
NUM_STAGES=3,
)

return output.view(*ori_shape, -1)


def embedding_backward(embeddings, indices, grad_output):
grad_output = grad_output.contiguous().view(-1, embeddings.shape[1])

grad_weight = torch.zeros_like(embeddings)

n_elements = indices.numel()
embedding_dim = embeddings.shape[1]
BLOCK_SIZE_N = triton.next_power_of_2(min(128, embedding_dim))
BLOCK_SIZE_M = get_optimal_block_size(n_elements, embeddings.element_size(), BLOCK_SIZE_N)
num_cores = get_npu_core_count()
total_blocks = triton.cdiv(n_elements, BLOCK_SIZE_M) * triton.cdiv(embedding_dim, BLOCK_SIZE_N)
grid = min(num_cores, total_blocks)

embedding_backward_kernel[(grid,)](
grad_output,
grad_weight,
indices,
n_elements,
embedding_dim=embedding_dim,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
NUM_STAGES=3,
)

return grad_weight


class LigerEmbeddingFunction(torch.autograd.Function):
@staticmethod
@ensure_contiguous
def forward(ctx, embeddings: torch.Tensor, indices: torch.Tensor):
output = embedding_forward(embeddings, indices)
ctx.save_for_backward(indices, embeddings)
return output

@staticmethod
@ensure_contiguous
def backward(ctx, grad_output: torch.Tensor):
indices, embeddings = ctx.saved_tensors
grad_weight = embedding_backward(embeddings, indices, grad_output)

return grad_weight, None
Loading