Commit 87cfbdf
implementation of fbgemm op - permute_multi_embedding (#2738)
Summary:
Pull Request resolved: #2738
X-link: meta-pytorch/torchrec#2120
# context
* current we have a working function `permute_pooled_embs_auto_grad` to do a full permute of KTs, including forward and backward
* it has several limitations:
a) it has to be a full permute, duplicates are not supported;
b) in the main [use case](https://fburl.com/code/89od0rqm) there has to be a torch.concat on the input KTs, which is not very efficient;
c) the function output a single KT which requires a split operation
* there is some attempt to support duplicated outputs, but the backward doesn't work
* this diff is trying to create a new kernel (named `permute_multi_embedding`) to support a multiple-KT to multiple-KT mapping operation with backward support
# notes
* this diff focuses on the implemenation and test of the operator
* performance analysis and benchmark are in the next diff
# operator example usage
* used in python
```
# test inputs: 3 KTs with batch_size=2048
batch_size = 2048
keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
lengths = [[96, 256], [512, 128, 768], [1024]]
values = [
torch.randn(batch_size, sum(lens), device="cuda", requires_grad=True)
for lens in lengths
]
# target outputs: 4 KTs with re-arranged keys (features), duplicates are allowed
groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]
# accessorial arguments to the op/kernel
permutes, in_lengths, out_lengths = _multi_remap_to_groups(
keys, lengths, groups
)
# arguments
outputs = torch.ops.fbgemm.permute_multi_embedding_internal_testing(
values, permutes, in_lengths, out_lengths
)
```
* permutes
```
# each row represents a key (feature) permute move, which consists of the following parameters:
# [input_tensor_idx, output_tensor_idx, input_key_idx, output_key_idx, key_length, magic_jump]
permutes = tensor(
[
[0, 0, 0, 0, 3, 4], # f1
[1, 0, 0, 3, 5, 0], # f3
[0, 1, 3, 0, 4, 0], # f2
[1, 2, 5, 0, 6, 0], # f4
[0, 2, 0, 6, 3, -6], # f1
[2, 2, 0, 9, 8, 0], # f6
[0, 3, 0, 0, 3, -8], # f1
[1, 3, 11, 3, 7, 0], # f5
]
)
```
# details
1. from the above example usage, we can clearly see that the operatior takes in the following:
a) values: List[torch.Tensor], which represents the input KTs
b) permutes: torch.Tensor, which contains the permute information, will be explained later
c) output_lengths_list: List[int], the lengths of the output tensors (KTs), which is needed to allocate memory on device ahead
d) in_lengths: torch.Tensor, lengths of input tensors, which is on device
e) out_lengths: torch.Tensor, lengths of output tensors, which is on device
2. the operator returns a list of tensors, which represents the permuted KTs
3. `permute` is the most critical argument in this operator:
a) 2-D tensor
b) each row represents a key (feature) permute move
c) a permute move = [input_tensor_id, output_tensor_id, input_start_idx, output_start_idx, feature_length, jump]
d) jump is used in backward when a key (feature) from the input tensor is mapped to multiple places in the output tensors
4. The magic_jump
a) It's only used in the backward computation
b) it's usually 0, means no jump
c) it's non-zero when there is a duplicate in the permute, e.g., the same feature appears more than once in the output
d) the `magic_jump` is the next index of the very same feature in the permute sequence with some modifications
e) modification-1: `magic_jump` is positive when it's the first of its kind [Start]
f) modification-2: `magic_jump` is negative when it's not the first of its kind [Continue]
g) modification-3: `magic_jump` is the negative value of the length of the permute sequence when it's the last of its kind. [Stop]
Reviewed By: sryap
Differential Revision: D57055616
fbshipit-source-id: 16673d3a2eafab93b08d4ff3c43d54366966064a1 parent bbdd76c commit 87cfbdf
File tree
6 files changed
+574
-0
lines changed- fbgemm_gpu
- fbgemm_gpu
- include/fbgemm_gpu
- src/permute_multi_embedding_ops
6 files changed
+574
-0
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
455 | 455 | | |
456 | 456 | | |
457 | 457 | | |
| 458 | + | |
| 459 | + | |
458 | 460 | | |
459 | 461 | | |
460 | 462 | | |
| |||
547 | 549 | | |
548 | 550 | | |
549 | 551 | | |
| 552 | + | |
550 | 553 | | |
551 | 554 | | |
552 | 555 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
19 | 19 | | |
20 | 20 | | |
21 | 21 | | |
| 22 | + | |
| 23 | + | |
| 24 | + | |
22 | 25 | | |
23 | 26 | | |
24 | 27 | | |
25 | 28 | | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
26 | 32 | | |
27 | 33 | | |
28 | 34 | | |
29 | 35 | | |
30 | 36 | | |
31 | 37 | | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
32 | 41 | | |
33 | 42 | | |
34 | 43 | | |
| |||
Lines changed: 69 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
| 1 | + | |
| 2 | + | |
| 3 | + | |
| 4 | + | |
| 5 | + | |
| 6 | + | |
| 7 | + | |
| 8 | + | |
| 9 | + | |
| 10 | + | |
| 11 | + | |
| 12 | + | |
| 13 | + | |
| 14 | + | |
| 15 | + | |
| 16 | + | |
| 17 | + | |
| 18 | + | |
| 19 | + | |
| 20 | + | |
| 21 | + | |
| 22 | + | |
| 23 | + | |
| 24 | + | |
| 25 | + | |
| 26 | + | |
| 27 | + | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
| 52 | + | |
| 53 | + | |
| 54 | + | |
| 55 | + | |
| 56 | + | |
| 57 | + | |
| 58 | + | |
| 59 | + | |
| 60 | + | |
| 61 | + | |
| 62 | + | |
| 63 | + | |
| 64 | + | |
| 65 | + | |
| 66 | + | |
| 67 | + | |
| 68 | + | |
| 69 | + | |
Lines changed: 77 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
| 1 | + | |
| 2 | + | |
| 3 | + | |
| 4 | + | |
| 5 | + | |
| 6 | + | |
| 7 | + | |
| 8 | + | |
| 9 | + | |
| 10 | + | |
| 11 | + | |
| 12 | + | |
| 13 | + | |
| 14 | + | |
| 15 | + | |
| 16 | + | |
| 17 | + | |
| 18 | + | |
| 19 | + | |
| 20 | + | |
| 21 | + | |
| 22 | + | |
| 23 | + | |
| 24 | + | |
| 25 | + | |
| 26 | + | |
| 27 | + | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
| 52 | + | |
| 53 | + | |
| 54 | + | |
| 55 | + | |
| 56 | + | |
| 57 | + | |
| 58 | + | |
| 59 | + | |
| 60 | + | |
| 61 | + | |
| 62 | + | |
| 63 | + | |
| 64 | + | |
| 65 | + | |
| 66 | + | |
| 67 | + | |
| 68 | + | |
| 69 | + | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
| 73 | + | |
| 74 | + | |
| 75 | + | |
| 76 | + | |
| 77 | + | |
0 commit comments