[jit][shape_prop] Fix jit registration of unpack_sizes ops for prepacked (pytorch#66737)

Summary: Pull Request resolved: pytorch#66737

Test Plan: Imported from OSS

Reviewed By: eellison

Differential Revision: D31703587

Pulled By: IvanKobzarev

fbshipit-source-id: ccebe5ffc4fa959e3fa63afab1058d94e9df9dd9
IvanKobzarev authored and facebook-github-bot committed Nov 1, 2021
1 parent 251278d commit e80cb08
Showing 7 changed files with 14 additions and 12 deletions.
6 changes: 3 additions & 3 deletions aten/src/ATen/native/xnnpack/Convolution.cpp
@@ -426,18 +426,18 @@ Tensor conv2d_clamp_run(
 }
 
 // Op is registered to have Any argument as we plan to reuse it for prepacked conv2d of other backends
-std::tuple<IntArrayRef, c10::optional<IntArrayRef>, IntArrayRef, IntArrayRef, IntArrayRef, int64_t>
+IValue
 unpack_prepacked_sizes_conv2d(const IValue& ivalue) {
   auto op_context = ivalue.toCustomClass<xnnpack::Conv2dOpContext>();
   const auto tuple = op_context->unpack();
   const auto& bias = std::get<1>(tuple);
-  return std::make_tuple(
+  return IValue(std::make_tuple(
       std::get<0>(tuple).sizes(),
       (bias && bias->defined()) ? c10::optional<IntArrayRef>(bias->sizes()) : c10::nullopt,
       std::get<2>(tuple),
       std::get<3>(tuple),
       std::get<4>(tuple),
-      std::get<5>(tuple));
+      std::get<5>(tuple)));
 }
 
 Tensor conv2d_transpose_clamp_run(
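
With the conv2d unpack op now returning an IValue-wrapped tuple, it can be called like any other custom op. A minimal sketch of inspecting the unpacked sizes from Python, assuming a PyTorch build with XNNPACK enabled; the shapes and prepack arguments below are illustrative, not taken from the commit:

    import torch

    # Illustrative 3x3 conv weights: 16 output channels, 8 input channels.
    weight = torch.randn(16, 8, 3, 3)
    bias = torch.randn(16)

    # Prepack via the op registered in RegisterOpContextClass.cpp below:
    # arguments are (W, B, stride, padding, dilation, groups).
    ctx = torch.ops.prepacked.conv2d_clamp_prepack(
        weight, bias, [1, 1], [0, 0], [1, 1], 1)

    # The schema now advertises Any, but the payload is still the 6-tuple
    # (weight_sizes, bias_sizes, stride, padding, dilation, groups).
    print(torch.ops.prepacked.unpack_prepacked_sizes_conv2d(ctx))
    # e.g. ([16, 8, 3, 3], [16], [1, 1], [0, 0], [1, 1], 1)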
2 changes: 1 addition & 1 deletion aten/src/ATen/native/xnnpack/Convolution.h
@@ -39,7 +39,7 @@ Tensor conv2d_clamp_run(
     const Tensor& input,
     const c10::intrusive_ptr<xnnpack::Conv2dOpContext>& op_context);
 
-std::tuple<IntArrayRef, c10::optional<IntArrayRef>, IntArrayRef, IntArrayRef, IntArrayRef, int64_t>
+IValue
 unpack_prepacked_sizes_conv2d(const IValue& ivalue);
 
 Tensor conv2d_transpose_clamp_run(
6 changes: 3 additions & 3 deletions aten/src/ATen/native/xnnpack/Linear.cpp
@@ -180,14 +180,14 @@ Tensor linear_clamp_run(
   return op_context->run(input);
 }
 
-std::tuple<IntArrayRef, c10::optional<IntArrayRef>>
+IValue
 unpack_prepacked_sizes_linear(const IValue& ivalue) {
   auto op_context = ivalue.toCustomClass<xnnpack::LinearOpContext>();
   const auto tuple = op_context->unpack();
   const auto& bias = std::get<1>(tuple);
-  return std::make_tuple(
+  return IValue(std::make_tuple(
       std::get<0>(tuple).sizes(),
-      (bias && bias->defined()) ? c10::optional<IntArrayRef>(bias->sizes()) : c10::nullopt);
+      (bias && bias->defined()) ? c10::optional<IntArrayRef>(bias->sizes()) : c10::nullopt));
 }
 
 } // namespace linear
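
The linear path mirrors the conv2d change. A hedged Python sketch, again assuming an XNNPACK-enabled build and illustrative shapes:

    import torch

    # Illustrative linear weights: 4 output features, 10 input features.
    ctx = torch.ops.prepacked.linear_clamp_prepack(
        torch.randn(4, 10), torch.randn(4))

    # Payload behind the Any return type: (weight_sizes, bias_sizes).
    print(torch.ops.prepacked.unpack_prepacked_sizes_linear(ctx))
    # e.g. ([4, 10], [4])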
2 changes: 1 addition & 1 deletion aten/src/ATen/native/xnnpack/Linear.h
@@ -20,7 +20,7 @@ c10::intrusive_ptr<xnnpack::LinearOpContext> createLinearClampPrePackOpContext(
 
 Tensor linear_clamp_run(const Tensor& input, const c10::intrusive_ptr<xnnpack::LinearOpContext>& op_context);
 
-std::tuple<IntArrayRef, c10::optional<IntArrayRef>>
+IValue
 unpack_prepacked_sizes_linear(const IValue& ivalue);
 
 ContextLinear create(
4 changes: 2 additions & 2 deletions aten/src/ATen/native/xnnpack/RegisterOpContextClass.cpp
@@ -82,8 +82,8 @@ TORCH_LIBRARY(xnnpack, m) {
 
 // Registration using the TORCH_LIBRARY def gives dispatching errors when there is no tensor input
 TORCH_LIBRARY(prepacked, m) {
-  m.def(TORCH_SELECTIVE_SCHEMA("prepacked::unpack_prepacked_sizes_conv2d(Any W_prepack) -> (int[], int[]?, int[], int[], int[], int)"), [](const IValue& inp) { return internal::convolution2d::unpack_prepacked_sizes_conv2d(inp);});
-  m.def(TORCH_SELECTIVE_SCHEMA("prepacked::unpack_prepacked_sizes_linear(Any W_prepack) -> (int[], int[]?)"), [](const IValue& inp) { return internal::linear::unpack_prepacked_sizes_linear(inp);});
+  m.def(TORCH_SELECTIVE_SCHEMA("prepacked::unpack_prepacked_sizes_conv2d(Any W_prepack) -> (Any)"), [](const IValue& inp) { return internal::convolution2d::unpack_prepacked_sizes_conv2d(inp);});
+  m.def(TORCH_SELECTIVE_SCHEMA("prepacked::unpack_prepacked_sizes_linear(Any W_prepack) -> (Any)"), [](const IValue& inp) { return internal::linear::unpack_prepacked_sizes_linear(inp);});
   m.def(TORCH_SELECTIVE_SCHEMA("prepacked::linear_clamp_prepack(Tensor W, Tensor? B=None, Scalar? output_min=None, Scalar? output_max=None) -> __torch__.torch.classes.xnnpack.LinearOpContext"));
   m.def(TORCH_SELECTIVE_SCHEMA("prepacked::linear_clamp_run(Tensor X, __torch__.torch.classes.xnnpack.LinearOpContext W_prepack) -> Tensor Y"));
   m.def(TORCH_SELECTIVE_SCHEMA("prepacked::conv2d_clamp_prepack(Tensor W, Tensor? B, int[2] stride, int[2] padding, int[2] dilation, int groups, Scalar? output_min=None, Scalar? output_max=None) -> __torch__.torch.classes.xnnpack.Conv2dOpContext"));
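
One way to confirm the updated registration is to dump the schemas the JIT knows about. A small sketch using torch._C._jit_get_all_schemas(), an internal binding (subject to change) that the repository's BC-check script also uses to enumerate schemas:

    import torch

    # After this change both unpack ops should advertise an Any return
    # type instead of the old multi-element tuple.
    for schema in torch._C._jit_get_all_schemas():
        if "unpack_prepacked_sizes" in schema.name:
            print(schema)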
2 changes: 2 additions & 0 deletions test/backward_compatibility/check_backward_compatibility.py
@@ -62,6 +62,8 @@
     ("aten::grid_sampler_2d_backward", datetime.date(2021, 10, 21)),
     ("prim::TensorExprDynamicGuard", datetime.date(2021, 11, 20)),
     ("caffe2::", datetime.date(2021, 10, 23)),
+    ("prepacked::unpack_prepacked_sizes_conv2d", datetime.date(9999, 1, 1)),
+    ("prepacked::unpack_prepacked_sizes_linear", datetime.date(9999, 1, 1)),
 ]
 
 ALLOW_LIST_COMPILED = [
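
The far-future date matters: the BC check treats each allowlist entry as a (pattern, expiry) pair, forgiving a schema change only until the entry's date passes. A minimal, simplified sketch of that gating logic, assuming regex matching on the schema name as in the real script:

    import datetime
    import re

    ALLOW_LIST = [
        ("prepacked::unpack_prepacked_sizes_conv2d", datetime.date(9999, 1, 1)),
        ("prepacked::unpack_prepacked_sizes_linear", datetime.date(9999, 1, 1)),
    ]

    def is_allowlisted(schema_name: str) -> bool:
        today = datetime.date.today()
        # date(9999, 1, 1) never expires, so these schema changes are
        # permanently exempt from the backward-compatibility check.
        return any(expiry > today and re.match(pattern, schema_name)
                   for pattern, expiry in ALLOW_LIST)

    print(is_allowlisted("prepacked::unpack_prepacked_sizes_conv2d"))  # True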
4 changes: 2 additions & 2 deletions torch/csrc/jit/runtime/symbolic_shape_registry.cpp
@@ -620,12 +620,12 @@ const std::string shape_compute_functions =
     R"(
         def prepacked_conv2d_clamp_run(input: List[int], conv2dOpContext: Any):
             assert isinstance(conv2dOpContext, __torch__.torch.classes.xnnpack.Conv2dOpContext)
-            (weight, bias, stride, padding, dilation, groups) = ops.prepacked.unpack_prepacked_sizes_conv2d(conv2dOpContext)
+            (weight, bias, stride, padding, dilation, groups) = unchecked_cast(Tuple[List[int], Optional[List[int]], List[int], List[int], List[int], int], ops.prepacked.unpack_prepacked_sizes_conv2d(conv2dOpContext))
             return conv2d(input, weight, bias, stride, padding, dilation, groups)
         def prepacked_linear_clamp_run(input: List[int], linearOpContext: Any):
             assert isinstance(linearOpContext, __torch__.torch.classes.xnnpack.LinearOpContext)
-            (weight, bias) = ops.prepacked.unpack_prepacked_sizes_linear(linearOpContext)
+            (weight, bias) = unchecked_cast(Tuple[List[int], Optional[List[int]]], ops.prepacked.unpack_prepacked_sizes_linear(linearOpContext))
             return linear(input, weight, bias)
     )"
 #endif
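
Since the op's declared return type is now Any, the TorchScript shape functions must refine it back to the tuple type before destructuring; unchecked_cast does that refinement without emitting a runtime check. A rough plain-Python analogy using typing.cast, which is likewise a no-op at runtime; the function name here is illustrative, not from the commit:

    from typing import Any, List, Optional, Tuple, cast

    def destructure_linear_sizes(
        packed: Any,
    ) -> Tuple[List[int], Optional[List[int]]]:
        # Like unchecked_cast: assert the type to the checker and trust
        # the producer (unpack_prepacked_sizes_linear) at runtime.
        weight, bias = cast(Tuple[List[int], Optional[List[int]]], packed)
        return weight, bias

    print(destructure_linear_sizes(([4, 10], [4])))  # ([4, 10], [4])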
