-
Notifications
You must be signed in to change notification settings - Fork 551
[NPU]: Add NPU support for the embedding #1028
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+220
−0
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,214 @@ | ||
| import torch | ||
| import triton | ||
| import triton.language as tl | ||
|
|
||
| from liger_kernel.ops.backends._ascend.ub_manager import compute_default_tiling_strategy | ||
| from liger_kernel.ops.utils import ensure_contiguous | ||
| from liger_kernel.ops.utils import get_npu_core_count | ||
|
|
||
|
|
||
| @triton.jit | ||
| def embedding_forward_kernel( | ||
| embeddings_ptr, | ||
| indices_ptr, | ||
| output_ptr, | ||
| n_elements, | ||
| embedding_dim: tl.constexpr, | ||
| BLOCK_SIZE_M: tl.constexpr, | ||
| BLOCK_SIZE_N: tl.constexpr, | ||
| NUM_STAGES: tl.constexpr, | ||
| ): | ||
| pid = tl.program_id(0) | ||
| num_progs = tl.num_programs(0) | ||
|
|
||
| grid_m = tl.cdiv(n_elements, BLOCK_SIZE_M) | ||
| grid_n = tl.cdiv(embedding_dim, BLOCK_SIZE_N) | ||
| total_2d_blocks = grid_m * grid_n | ||
|
|
||
| for block_idx in tl.range(pid, total_2d_blocks, num_progs, num_stages=NUM_STAGES): | ||
| block_m = block_idx // grid_n | ||
| block_n = block_idx % grid_n | ||
|
|
||
| start_m = block_m * BLOCK_SIZE_M | ||
| start_n = block_n * BLOCK_SIZE_N | ||
|
|
||
| offsets_m = start_m + tl.arange(0, BLOCK_SIZE_M) | ||
| mask_m = offsets_m < n_elements | ||
|
|
||
| indices = tl.load(indices_ptr + offsets_m, mask=mask_m, other=0) | ||
|
|
||
| offsets_n = start_n + tl.arange(0, BLOCK_SIZE_N) | ||
| mask_n = offsets_n < embedding_dim | ||
|
|
||
| block_mask = mask_m[:, None] & mask_n[None, :] | ||
|
|
||
| embedding_offsets = indices[:, None] * embedding_dim + offsets_n[None, :] | ||
| embeddings = tl.load( | ||
| embeddings_ptr + embedding_offsets, | ||
| mask=block_mask, | ||
| other=0.0, | ||
| ) | ||
|
|
||
| output_offsets = offsets_m[:, None] * embedding_dim + offsets_n[None, :] | ||
| tl.store( | ||
| output_ptr + output_offsets, | ||
| embeddings, | ||
| mask=block_mask, | ||
| ) | ||
|
|
||
|
|
||
| @triton.jit | ||
| def embedding_backward_kernel( | ||
| grad_output_ptr, | ||
| grad_weight_ptr, | ||
| indices_ptr, | ||
| n_elements, | ||
| embedding_dim: tl.constexpr, | ||
| BLOCK_SIZE_M: tl.constexpr, | ||
| BLOCK_SIZE_N: tl.constexpr, | ||
| NUM_STAGES: tl.constexpr, | ||
| ): | ||
| pid = tl.program_id(0) | ||
| num_progs = tl.num_programs(0) | ||
|
|
||
| grid_m = tl.cdiv(n_elements, BLOCK_SIZE_M) | ||
| grid_n = tl.cdiv(embedding_dim, BLOCK_SIZE_N) | ||
| total_2d_blocks = grid_m * grid_n | ||
|
|
||
| for block_idx in tl.range(pid, total_2d_blocks, num_progs, num_stages=NUM_STAGES): | ||
| block_m = block_idx // grid_n | ||
| block_n = block_idx % grid_n | ||
|
|
||
| start_m = block_m * BLOCK_SIZE_M | ||
| start_n = block_n * BLOCK_SIZE_N | ||
|
|
||
| offsets_m = start_m + tl.arange(0, BLOCK_SIZE_M) | ||
| mask_m = offsets_m < n_elements | ||
|
|
||
| indices = tl.load(indices_ptr + offsets_m, mask=mask_m, other=0) | ||
|
|
||
| offsets_n = start_n + tl.arange(0, BLOCK_SIZE_N) | ||
| mask_n = offsets_n < embedding_dim | ||
|
|
||
| block_mask = mask_m[:, None] & mask_n[None, :] | ||
|
|
||
| grad_output_offsets = offsets_m[:, None] * embedding_dim + offsets_n[None, :] | ||
| grad_output = tl.load( | ||
| grad_output_ptr + grad_output_offsets, | ||
| mask=block_mask, | ||
| other=0.0, | ||
| ) | ||
|
|
||
| grad_weight_offsets = indices[:, None] * embedding_dim + offsets_n[None, :] | ||
| tl.atomic_add( | ||
| grad_weight_ptr + grad_weight_offsets, | ||
| grad_output, | ||
| mask=block_mask, | ||
| ) | ||
|
|
||
|
|
||
| def get_optimal_block_size(total_elements, dtype_size, BLOCK_SIZE_N: tl.constexpr): | ||
| # 1. Set Memory Multiplier | ||
| # 3.0 are empirical values based on 910B UB (192KB) | ||
| # embedding_offsets, embedding_offsets : BLOCK_SIZE_N * BLOCK_SIZE_M (total 2 * BLOCK_SIZE_N * BLOCK_SIZE_M) | ||
| # Reserve a unit of space for the remaining one-dimensional ub to occupy. | ||
| # A conservative estimate of the total space occupation is 3 * BLOCK_SIZE_N * BLOCK_SIZE_M | ||
| multiplier = 3.0 | ||
|
|
||
| # 2. Call calculation function | ||
| # Treat input as 1D (total_elements,), only tiling on dim 0 | ||
| tile_shapes = compute_default_tiling_strategy( | ||
| safety_margin=0.9, | ||
| dtype_size=dtype_size, | ||
| memory_multiplier=multiplier, | ||
| shapes=((total_elements, BLOCK_SIZE_N),), | ||
| tiling_dims=(0,), | ||
| ) | ||
|
Comment on lines
+120
to
+126
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. dtype_size should be
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Modified |
||
|
|
||
| # 3. Parse result | ||
| if tile_shapes and len(tile_shapes) > 0: | ||
| block_size = tile_shapes[0][0] | ||
| return block_size | ||
| else: | ||
| return triton.next_power_of_2(min(128, total_elements)) | ||
|
|
||
|
|
||
| def embedding_forward(embeddings, indices): | ||
| ori_shape = indices.shape | ||
| indices = indices.view(-1) | ||
|
|
||
| n_elements = indices.numel() | ||
| embedding_dim = embeddings.shape[1] | ||
| output = torch.empty( | ||
| indices.shape[0], | ||
| embeddings.shape[1], | ||
| device=indices.device, | ||
| dtype=embeddings.dtype, | ||
| ) | ||
|
|
||
| # Due to the involvement of two-dimensional partitioning, | ||
| # the sizes of block_m and block_n in the ub space will influence each other. | ||
| # Considering that embedding_dim is usually relatively smaller in most cases, | ||
| # a value is first assigned to block_n, and then the largest possible block_m is used. | ||
| BLOCK_SIZE_N = triton.next_power_of_2(min(128, embedding_dim)) | ||
| BLOCK_SIZE_M = get_optimal_block_size(n_elements, embeddings.element_size(), BLOCK_SIZE_N) | ||
| num_cores = get_npu_core_count() | ||
| total_blocks = triton.cdiv(n_elements, BLOCK_SIZE_M) * triton.cdiv(embedding_dim, BLOCK_SIZE_N) | ||
| grid = min(num_cores, total_blocks) | ||
|
|
||
| embedding_forward_kernel[(grid,)]( | ||
| embeddings, | ||
| indices, | ||
| output, | ||
| n_elements, | ||
| embedding_dim=embedding_dim, | ||
| BLOCK_SIZE_M=BLOCK_SIZE_M, | ||
| BLOCK_SIZE_N=BLOCK_SIZE_N, | ||
| NUM_STAGES=3, | ||
| ) | ||
|
|
||
| return output.view(*ori_shape, -1) | ||
|
|
||
|
|
||
| def embedding_backward(embeddings, indices, grad_output): | ||
| grad_output = grad_output.contiguous().view(-1, embeddings.shape[1]) | ||
|
|
||
| grad_weight = torch.zeros_like(embeddings) | ||
|
|
||
| n_elements = indices.numel() | ||
| embedding_dim = embeddings.shape[1] | ||
| BLOCK_SIZE_N = triton.next_power_of_2(min(128, embedding_dim)) | ||
| BLOCK_SIZE_M = get_optimal_block_size(n_elements, embeddings.element_size(), BLOCK_SIZE_N) | ||
| num_cores = get_npu_core_count() | ||
| total_blocks = triton.cdiv(n_elements, BLOCK_SIZE_M) * triton.cdiv(embedding_dim, BLOCK_SIZE_N) | ||
| grid = min(num_cores, total_blocks) | ||
|
|
||
| embedding_backward_kernel[(grid,)]( | ||
| grad_output, | ||
| grad_weight, | ||
| indices, | ||
| n_elements, | ||
| embedding_dim=embedding_dim, | ||
| BLOCK_SIZE_M=BLOCK_SIZE_M, | ||
| BLOCK_SIZE_N=BLOCK_SIZE_N, | ||
| NUM_STAGES=3, | ||
| ) | ||
|
|
||
| return grad_weight | ||
|
|
||
|
|
||
| class LigerEmbeddingFunction(torch.autograd.Function): | ||
| @staticmethod | ||
| @ensure_contiguous | ||
| def forward(ctx, embeddings: torch.Tensor, indices: torch.Tensor): | ||
| output = embedding_forward(embeddings, indices) | ||
| ctx.save_for_backward(indices, embeddings) | ||
| return output | ||
|
|
||
| @staticmethod | ||
| @ensure_contiguous | ||
| def backward(ctx, grad_output: torch.Tensor): | ||
| indices, embeddings = ctx.saved_tensors | ||
| grad_weight = embedding_backward(embeddings, indices, grad_output) | ||
|
|
||
| return grad_weight, None | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the original implementation with 2 block sizes for tile shape is more readable and more efficient.
persistant grid loop is fine, but the way this kernel loading embedding seems to be uncoalesced at some point.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For instance, there will be some
dim_idxnot consecutive ifBLOCK_SIZEis not multiple ofembedding_dim. It will make the secondtl.loadtrying to access different rows within a warp, as well as the last store.Make these offsets created with 2d block size is more readable and efficient since we can avoid the uncoalesced access mentioned above.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have changed it to 2D block. After testing, it has indeed shown much better performance. The issues mentioned below have also been fixed. Could you please review it for me again?