
Commit 87cfbdf

TroyGarden authored and facebook-github-bot committed
implementation of fbgemm op - permute_multi_embedding (#2738)
Summary:
Pull Request resolved: #2738

X-link: meta-pytorch/torchrec#2120

# context
* currently we have a working function `permute_pooled_embs_auto_grad` that does a full permute of KTs (KeyedTensors), including forward and backward
* it has several limitations: a) it has to be a full permute, duplicates are not supported; b) in the main [use case](https://fburl.com/code/89od0rqm) there has to be a torch.concat on the input KTs, which is not very efficient; c) the function outputs a single KT, which then requires a split operation
* there has been some attempt to support duplicated outputs, but the backward doesn't work
* this diff creates a new kernel (named `permute_multi_embedding`) to support a multiple-KT to multiple-KT mapping operation with backward support

# notes
* this diff focuses on the implementation and tests of the operator
* performance analysis and benchmarks are in the next diff

# operator example usage
* used in python
```
# test inputs: 3 KTs with batch_size=2048
batch_size = 2048
keys = [["f1", "f2"], ["f3", "f4", "f5"], ["f6"]]
lengths = [[96, 256], [512, 128, 768], [1024]]
values = [
    torch.randn(batch_size, sum(lens), device="cuda", requires_grad=True)
    for lens in lengths
]

# target outputs: 4 KTs with re-arranged keys (features); duplicates are allowed
groups = [["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]]

# auxiliary arguments to the op/kernel
permutes, in_lengths, out_lengths = _multi_remap_to_groups(
    keys, lengths, groups
)

# call the operator
outputs = torch.ops.fbgemm.permute_multi_embedding_internal_testing(
    values, permutes, in_lengths, out_lengths
)
```
* permutes
```
# each row represents a key (feature) permute move, which consists of the following parameters:
# [input_tensor_idx, output_tensor_idx, input_key_idx, output_key_idx, key_length, magic_jump]
permutes = tensor(
    [
        [0, 0, 0, 0, 3, 4],   # f1
        [1, 0, 0, 3, 5, 0],   # f3
        [0, 1, 3, 0, 4, 0],   # f2
        [1, 2, 5, 0, 6, 0],   # f4
        [0, 2, 0, 6, 3, -6],  # f1
        [2, 2, 0, 9, 8, 0],   # f6
        [0, 3, 0, 0, 3, -8],  # f1
        [1, 3, 11, 3, 7, 0],  # f5
    ]
)
```
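For intuition, the `permutes` tensor above can be reproduced with a small standalone helper. The sketch below is illustrative only, not the torchrec implementation of `_multi_remap_to_groups`: the name `build_permutes` is hypothetical, the per-feature lengths `[[3, 4], [5, 6, 7], [8]]` are chosen so that its output matches the rows shown above, and the `magic_jump` convention follows the description in the details section below.
```
# Illustrative sketch only -- NOT the actual torchrec `_multi_remap_to_groups`.
# It derives the permute rows [in_tensor, out_tensor, in_offset, out_offset,
# length, magic_jump] from keys/lengths/groups as described in this summary.
from typing import Dict, List, Tuple

import torch


def build_permutes(  # hypothetical helper name
    keys: List[List[str]],
    lengths: List[List[int]],
    groups: List[List[str]],
) -> Tuple[torch.Tensor, List[int], List[int]]:
    # map each feature to (input_tensor_idx, offset_within_input, length)
    feature_src: Dict[str, Tuple[int, int, int]] = {}
    in_lengths: List[int] = []
    for in_idx, (ks, ls) in enumerate(zip(keys, lengths)):
        offset = 0
        for key, length in zip(ks, ls):
            feature_src[key] = (in_idx, offset, length)
            offset += length
        in_lengths.append(offset)

    # one permute row per (output group, feature), in output order
    rows: List[List[int]] = []
    feature_rows: Dict[str, List[int]] = {}  # rows reading the same feature
    out_lengths: List[int] = []
    for out_idx, group in enumerate(groups):
        out_offset = 0
        for key in group:
            in_idx, in_offset, length = feature_src[key]
            feature_rows.setdefault(key, []).append(len(rows))
            rows.append([in_idx, out_idx, in_offset, out_offset, length, 0])
            out_offset += length
        out_lengths.append(out_offset)

    # magic_jump for duplicated features (see the "details" section):
    #   first occurrence  -> +index of the next occurrence   [Start]
    #   middle occurrence -> -index of the next occurrence   [Continue]
    #   last occurrence   -> -len(rows)                      [Stop]
    for occurrences in feature_rows.values():
        if len(occurrences) < 2:
            continue
        for pos, row in enumerate(occurrences):
            if pos + 1 < len(occurrences):
                nxt = occurrences[pos + 1]
                rows[row][5] = nxt if pos == 0 else -nxt
            else:
                rows[row][5] = -len(rows)

    return torch.tensor(rows), in_lengths, out_lengths


# with feature lengths [[3, 4], [5, 6, 7], [8]], this reproduces the
# `permutes` tensor shown above
permutes, in_lengths, out_lengths = build_permutes(
    keys=[["f1", "f2"], ["f3", "f4", "f5"], ["f6"]],
    lengths=[[3, 4], [5, 6, 7], [8]],
    groups=[["f1", "f3"], ["f2"], ["f4", "f1", "f6"], ["f1", "f5"]],
)
```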
# details
1. from the example usage above, the operator takes in the following:
   a) `values: List[torch.Tensor]`, which represents the input KTs
   b) `permutes: torch.Tensor`, which contains the permute information (explained below)
   c) `output_lengths_list: List[int]`, the lengths of the output tensors (KTs), which is needed to allocate memory on device ahead of time
   d) `in_lengths: torch.Tensor`, lengths of the input tensors, which is on device
   e) `out_lengths: torch.Tensor`, lengths of the output tensors, which is on device
2. the operator returns a list of tensors, which represents the permuted KTs
3. `permutes` is the most critical argument of this operator:
   a) it is a 2-D tensor
   b) each row represents a key (feature) permute move
   c) a permute move = [input_tensor_idx, output_tensor_idx, input_start_idx, output_start_idx, feature_length, magic_jump]
   d) the jump is used in the backward pass when a key (feature) from an input tensor is mapped to multiple places in the output tensors
4. the `magic_jump`:
   a) it is only used in the backward computation
   b) it is usually 0, meaning no jump
   c) it is non-zero when there is a duplicate in the permute, e.g., the same feature appears more than once in the output
   d) the `magic_jump` is the index of the next occurrence of the very same feature in the permute sequence, with some modifications
   e) modification-1: the `magic_jump` is positive when it is the first occurrence of its kind [Start]
   f) modification-2: the `magic_jump` is negative when it is not the first occurrence of its kind [Continue]
   g) modification-3: the `magic_jump` is the negative of the length of the permute sequence when it is the last occurrence of its kind [Stop]
   h) for example, `f1` appears three times in the `permutes` tensor above (rows 0, 4, and 6): row 0 carries `4` (Start, pointing to the next occurrence at row 4), row 4 carries `-6` (Continue, pointing to row 6), and row 6 carries `-8` (Stop, the negative of the sequence length 8); the gradient-accumulation sketch after this summary shows what the backward has to compute for such duplicates

Reviewed By: sryap

Differential Revision: D57055616

fbshipit-source-id: 16673d3a2eafab93b08d4ff3c43d54366966064a
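To see why duplicated features need special handling in backward (and hence the `magic_jump` chain), note that an input feature copied to several outputs must receive the sum of the corresponding gradient slices. The sketch below is an eager-mode reference of that accumulation only; it ignores `magic_jump` (which the kernel uses to chain together the duplicate rows of a feature), and the function name `reference_backward` and its signature are illustrative, not part of this diff.
```
# Reference-only sketch of the backward accumulation (names are illustrative):
# each input column receives the sum of the gradients of every output slice it
# was copied to in the forward pass.
from typing import List

import torch


def reference_backward(
    grad_outputs: List[torch.Tensor],  # gradients w.r.t. the output KTs
    permutes: torch.Tensor,            # same layout as the forward `permutes`
    in_lengths: List[int],             # per-input-KT total feature lengths
    batch_size: int,
) -> List[torch.Tensor]:
    grads = [
        torch.zeros(batch_size, length, device=grad_outputs[0].device)
        for length in in_lengths
    ]
    for in_t, out_t, in_off, out_off, length, _magic_jump in permutes.tolist():
        # duplicated features simply accumulate here; the CUDA kernel instead
        # follows the magic_jump chain to visit all duplicates of a feature
        grads[in_t][:, in_off:in_off + length] += (
            grad_outputs[out_t][:, out_off:out_off + length]
        )
    return grads
```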
1 parent bbdd76c commit 87cfbdf

File tree

6 files changed, +574 -0 lines changed


fbgemm_gpu/FbgemmGpu.cmake

Lines changed: 3 additions & 0 deletions
@@ -455,6 +455,8 @@ set(fbgemm_gpu_sources_static_cpu
     codegen/training/backward/embedding_backward_dense_host_cpu.cpp
     codegen/utils/embedding_bounds_check_host_cpu.cpp
     src/merge_pooled_embedding_ops/merge_pooled_embedding_ops_cpu.cpp
+    src/permute_multi_embedding_ops/permute_multi_embedding_function.cpp
+    src/permute_multi_embedding_ops/permute_multi_embedding_ops_cpu.cpp
     src/permute_pooled_embedding_ops/permute_pooled_embedding_function.cpp
     src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_cpu.cpp
     src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split_cpu.cpp
@@ -547,6 +549,7 @@ if(NOT FBGEMM_CPU_ONLY)
     src/metric_ops/metric_ops.cu
     src/permute_pooled_embedding_ops/permute_pooled_embedding_ops_split.cu
     src/permute_pooled_embedding_ops/permute_pooled_embedding_ops.cu
+    src/permute_multi_embedding_ops/permute_multi_embedding_ops.cu
     src/quantize_ops/quantize_bfloat16.cu
     src/quantize_ops/quantize_fp8_rowwise.cu
     src/quantize_ops/quantize_fused_8bit_rowwise.cu

fbgemm_gpu/fbgemm_gpu/permute_pooled_embedding_modules.py

Lines changed: 9 additions & 0 deletions
@@ -19,16 +19,25 @@
     torch.ops.load_library(
         "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_cpu"
     )
+    torch.ops.load_library(
+        "//deeplearning/fbgemm/fbgemm_gpu:permute_multi_embedding_ops_cpu"
+    )
     try:
         torch.ops.load_library(
             "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_gpu"
         )
+        torch.ops.load_library(
+            "//deeplearning/fbgemm/fbgemm_gpu:permute_multi_embedding_ops_gpu"
+        )
     except OSError:
         # This is for forward compatibility (new torch.package + old backend)
         # We should be able to remove it after this diff is picked up by all backend
         torch.ops.load_library(
             "//deeplearning/fbgemm/fbgemm_gpu:permute_pooled_embedding_ops_gpu_cuda"
         )
+        torch.ops.load_library(
+            "//deeplearning/fbgemm/fbgemm_gpu:permute_multi_embedding_ops_gpu_cuda"
+        )
 except OSError:
     pass

Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <ATen/ATen.h>
+#include <ATen/Parallel.h>
+#include <torch/csrc/api/include/torch/types.h>
+#include <torch/csrc/autograd/custom_function.h>
+
+#include "fbgemm_gpu/dispatch_macros.h"
+#include "fbgemm_gpu/ops_utils.h"
+#include "fbgemm_gpu/sparse_ops_utils.h"
+
+namespace fbgemm_gpu {
+
+using Tensor = at::Tensor;
+using torch::autograd::AutogradContext;
+using torch::autograd::variable_list;
+
+using Tensor = at::Tensor;
+using torch::autograd::AutogradContext;
+using torch::autograd::variable_list;
+
+class PermuteMultiEmbeddingOp
+    : public torch::autograd::Function<PermuteMultiEmbeddingOp> {
+ public:
+  static variable_list forward(
+      AutogradContext* ctx,
+      const at::TensorList& pooled_embs,
+      const Tensor& permutes,
+      const Tensor& in_shapes,
+      const Tensor& out_shapes,
+      const std::vector<int64_t>& out_lengths);
+
+  static variable_list backward(
+      AutogradContext* ctx,
+      variable_list grad_output);
+};
+
+std::vector<Tensor> permute_multi_embedding_cpu(
+    const at::TensorList& pooled_embs,
+    const Tensor& permutes,
+    const Tensor& in_shapes,
+    const Tensor& out_shapes,
+    const std::vector<int64_t>& out_lengths,
+    const bool& reverse_permute);
+
+std::vector<Tensor> permute_multi_embedding_meta(
+    const at::TensorList& pooled_embs,
+    const Tensor& permutes,
+    const Tensor& in_shapes,
+    const Tensor& out_shapes,
+    const std::vector<int64_t>& out_lengths,
+    const bool& reverse_permute);
+
+std::vector<Tensor> permute_multi_embedding_gpu(
+    const at::TensorList& pooled_embs,
+    const Tensor& permutes,
+    const Tensor& in_shapes,
+    const Tensor& out_shapes,
+    const std::vector<int64_t>& out_lengths,
+    const bool& reverse_permute);
+} // namespace fbgemm_gpu
Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include "fbgemm_gpu/permute_multi_embedding_function.h"
+#include <cstdint>
+#include <iostream>
+
+namespace fbgemm_gpu {
+
+using Tensor = at::Tensor;
+using torch::autograd::AutogradContext;
+using torch::autograd::variable_list;
+
+variable_list PermuteMultiEmbeddingOp::forward(
+    AutogradContext* ctx,
+    const at::TensorList& pooled_embs,
+    const Tensor& permutes,
+    const Tensor& in_shapes,
+    const Tensor& out_shapes,
+    const std::vector<int64_t>& out_lengths) {
+  ctx->saved_data["permutes"] = permutes;
+  ctx->saved_data["in_shapes"] = in_shapes;
+  ctx->saved_data["out_shapes"] = out_shapes;
+
+  std::vector<int64_t> in_lengths;
+  in_lengths.reserve(pooled_embs.size());
+  for (auto i : c10::irange(pooled_embs.size())) {
+    in_lengths.push_back(pooled_embs[i].size(1));
+  }
+  ctx->saved_data["in_lengths"] = in_lengths;
+
+  /*
+    select the correct dispatched (cpu/gpu) forward function
+    the cpu/gpu function needs to be registered in the dispatcher,
+    e.g., DISPATCH_TO_CPU, DISPATCH_TO_CUDA, etc.
+  */
+  const auto permute_op =
+      torch::Dispatcher::singleton()
+          .findSchemaOrThrow("fbgemm::permute_multi_embedding_function", "")
+          .typed<decltype(permute_multi_embedding_cpu)>();
+
+  return permute_op.call(
+      pooled_embs, permutes, in_shapes, out_shapes, out_lengths, false);
+}
+
+variable_list PermuteMultiEmbeddingOp::backward(
+    AutogradContext* ctx,
+    variable_list grad_output) {
+  const auto permutes = ctx->saved_data["permutes"].toTensor();
+  const auto in_shapes = ctx->saved_data["in_shapes"].toTensor();
+  const auto out_shapes = ctx->saved_data["out_shapes"].toTensor();
+  const auto in_lengths = ctx->saved_data["in_lengths"].toIntVector();
+
+  /*
+    select the correct dispatched (cpu/gpu) backward function
+    the cpu/gpu function needs to be registered in the dispatcher,
+    e.g., DISPATCH_TO_CPU, DISPATCH_TO_CUDA, etc.
+  */
+  const auto permute_op =
+      torch::Dispatcher::singleton()
+          .findSchemaOrThrow("fbgemm::permute_multi_embedding_function", "")
+          .typed<decltype(permute_multi_embedding_cpu)>();
+  auto grad_input = permute_op.call(
+      grad_output, permutes, out_shapes, in_shapes, in_lengths, true);
+  grad_input.push_back(torch::autograd::Variable()); // permutes
+  grad_input.push_back(torch::autograd::Variable()); // in_shapes
+  grad_input.push_back(torch::autograd::Variable()); // out_shapes
+  grad_input.push_back(torch::autograd::Variable()); // out_lengths
+  return grad_input;
+}
+
+} // namespace fbgemm_gpu
