Skip to content

Commit

Permalink
Add meta implementation for _efficientzerotensor (pytorch#88936)
Browse files Browse the repository at this point in the history
`_efficientzerotensor` is used in several backwards formulas, so its
lack of meta implementation makes those functions untracable.

Pull Request resolved: pytorch#88936
Approved by: https://github.com/anjali411
  • Loading branch information
peterbell10 authored and pytorchmergebot committed Nov 28, 2022
1 parent 69a8c92 commit 2e0cd7c
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 2 deletions.
13 changes: 13 additions & 0 deletions aten/src/ATen/native/TensorFactories.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1161,6 +1161,19 @@ Tensor _efficientzerotensor(IntArrayRef size,
return out;
}

Tensor _efficientzerotensor_meta(IntArrayRef size,
c10::optional<ScalarType> dtype,
c10::optional<Layout> layout,
c10::optional<Device> device,
c10::optional<bool> pin_memory) {
auto device_ = device_or_default(device);
auto allocator = at::native::ZeroTensorAllocator(device_);
auto dtype_ = dtype_or_default(dtype);
auto zero_ks = at::DispatchKeySet(c10::DispatchKey::Meta) | at::DispatchKeySet(c10::DispatchKey::ZeroTensor);
auto out = at::detail::empty_generic(size, &allocator, zero_ks, dtype_, c10::nullopt);
return out;
}

Tensor& zeros_sparse_out(IntArrayRef size, Tensor& result) {
result.sparse_resize_and_clear_(size, size.size(), 0.);
return result;
Expand Down
1 change: 1 addition & 0 deletions aten/src/ATen/native/native_functions.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5772,6 +5772,7 @@
dispatch:
CPU: _efficientzerotensor
CUDA: _efficientzerotensor_cuda
Meta: _efficientzerotensor_meta
autogen: _efficientzerotensor.out

- func: zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
Expand Down
2 changes: 0 additions & 2 deletions test/inductor/test_torchinductor_opinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,6 @@ def process(device_type):
"scatter_reduce.sum": {f16},
"scatter_reduce.prod": {f16, f32, f64},
"segment_reduce.lengths": {f16, f32, f64},
"sgn": {f16, f32, f64},
"sparse.sampled_addmm": {f32, f64},
"stft": {f32, f64},
"svd_lowrank": {f32, f64},
Expand Down Expand Up @@ -332,7 +331,6 @@ def process(device_type):
"round.decimals_3": {f16},
"scatter_reduce.prod": {f16, f32, f64},
"segment_reduce.lengths": {f16, f32, f64},
"sgn": {f16, f32, f64},
"sparse.sampled_addmm": {f32, f64},
"stft": {f32, f64},
"svd_lowrank": {f32, f64},
Expand Down

0 comments on commit 2e0cd7c

Please sign in to comment.