Fix Use-after-Free in qembeddingbag_byte_prepack_out (pytorch#84750)
When FBGEMM is not used (either manually disabled or on platforms such as POWER where it isn't supported at all), the fallback code requests a `data_ptr<float>` on a temporary `Tensor` returned by `to(ScalarType::Float)` within the same expression. That temporary is destroyed at the end of the statement, leaving a dangling pointer.

On some platforms this manifests as wrong results once the memory is overwritten. On other platforms anything may happen, since this is undefined behavior; most likely it will crash or keep returning semi-random results, which may even happen to be correct while the memory has not yet been reused.

Fix this by binding the temporary object (or the original object) to a const lvalue reference, which extends its lifetime, and getting the `data_ptr` from that.
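For illustration, a minimal standalone sketch of both patterns, using a hypothetical `FakeTensor` stand-in rather than the actual ATen API: calling the accessor on a temporary leaves a dangling pointer, whereas binding the temporary to a const lvalue reference extends its lifetime to that of the reference.

```cpp
#include <vector>

// Hypothetical stand-in for a Tensor that owns a float buffer (not the ATen API).
struct FakeTensor {
  std::vector<float> storage{1.0f, 2.0f, 3.0f};
  const float* data_ptr() const { return storage.data(); }
  // Returns a converted copy by value, analogous to Tensor::to(ScalarType::Float).
  FakeTensor to_float() const { return FakeTensor{storage}; }
};

int main() {
  FakeTensor weight;

  // Buggy pattern: the temporary returned by to_float() is destroyed at the
  // end of the full expression, so `dangling` points into freed storage.
  const float* dangling = weight.to_float().data_ptr();
  (void)dangling;  // any later dereference is undefined behavior

  // Fixed pattern: binding the temporary to a const lvalue reference extends
  // its lifetime to the lifetime of the reference, so the pointer stays valid.
  const FakeTensor& converted = weight.to_float();
  const float* valid = converted.data_ptr();
  (void)valid;  // safe to read while `converted` is in scope
}
```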

Fixes pytorch#84748

This bug was introduced by a seemingly unrelated change in pytorch#64081, hence cc'ing @d1jang.

Pull Request resolved: pytorch#84750
Approved by: https://github.com/kimishpatel
Flamefire authored and pytorchmergebot committed Nov 23, 2022
1 parent 07dd2fe commit 7594e04
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp
```diff
@@ -266,9 +266,10 @@ Tensor& qembeddingbag_byte_prepack_out(Tensor& output, const Tensor& weight) {
   }
 
 #else
-  const auto weight_data = weight_contig->scalar_type() == at::ScalarType::Half
-      ? weight_contig->to(at::ScalarType::Float).data_ptr<float>()
-      : weight_contig->data_ptr<float>();
+  const Tensor& float_weight = weight_contig->scalar_type() == at::ScalarType::Half
+      ? weight_contig->to(at::ScalarType::Float)
+      : *weight_contig;
+  const auto weight_data = float_weight.data_ptr<float>();
   constexpr float kEpsilon = 1e-8f;
   for (auto row : c10::irange(embedding_rows)) {
     const float* input_row = weight_data + row * embedding_cols;
```
