Skip to content

Commit

Permalink
Update impl_abstract_pystub to be less boilerplatey
Browse files Browse the repository at this point in the history
Summary:
X-link: pytorch/pytorch#112851

We've made the following changes:
- The new way to use the API is `m.impl_abstract_pystub(module, context)`.
  Every subsequent m.def of an op inside the TORCH_LIBRARY block gives
  the op the `impl_abstract_pystub`.
- Added a mechanism to determine if an operator was defined in Python or C++.
  Library.define in Python appends the op to a global set, which is analogous
  to what we do for tracking Library.impl.
- If someone does `torch.library.impl_abstract` in Python for an operator, then
  we require that it has an `impl_abstract_pystub` specified and we also check
  that the module in the `impl_abstract_pystub` is the same as the module where
  the call to `torch.library.impl_abstract` exists.
- Unfortunately we can't check the "context" (which is the buck target on
  buck-based systems) because buck sits above us.

bypass-github-export-checks

Reviewed By: ezyang

Differential Revision: D50972148

fbshipit-source-id: 34ab31493d9bccd0d351b463cfb508ba0b05eef4
  • Loading branch information
zou3519 authored and facebook-github-bot committed Nov 7, 2023
1 parent 21d0c95 commit b6bdf04
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ Tensor merge_pooled_embeddings_cpu(
} // namespace fbgemm_gpu

TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
#ifdef HAS_IMPL_ABSTRACT_PYSTUB
m.impl_abstract_pystub(
"fbgemm_gpu.sparse_ops",
"//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_py");
#endif
m.def(
"merge_pooled_embeddings(Tensor[] pooled_embeddings, SymInt uncat_dim_size, Device target_device, SymInt cat_dim=1) -> Tensor");
m.def(
Expand Down
5 changes: 5 additions & 0 deletions fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2686,6 +2686,11 @@ Tensor bottom_k_per_row(
} // namespace fbgemm_gpu

TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
#ifdef HAS_IMPL_ABSTRACT_PYSTUB
m.impl_abstract_pystub(
"fbgemm_gpu.sparse_ops",
"//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_py");
#endif
m.def(
"permute_sparse_data(Tensor permute, Tensor lengths, Tensor values, Tensor? weights=None, SymInt? permuted_lengths_sum=None) -> (Tensor, Tensor, Tensor?)");
m.def(
Expand Down

0 comments on commit b6bdf04

Please sign in to comment.