@@ -186,8 +186,7 @@ Tensor expand_batching_rule(const Tensor& self, IntArrayRef size, bool implicit)
 }
 
 Tensor expand_symint_batching_rule(const Tensor& self, SymIntArrayRef psize, bool implicit) {
-  // TODO: properly support this
-  return expand_batching_rule(self, asIntArrayRefSlow(psize), implicit);
+  return self.expand(asIntArrayRefSlow(psize), implicit);
 }
 
 std::vector<Tensor> chunk_batching_rule(const Tensor& self, int64_t chunks, int64_t dim) {
@@ -470,8 +469,7 @@ Tensor view_batching_rule(const Tensor& self, IntArrayRef size) {
 }
 
 Tensor view_symint_batching_rule(const Tensor& self, c10::SymIntArrayRef size) {
-  // TODO: properly support this
-  return view_batching_rule(self, asIntArrayRefSlow(size));
+  return self.view(asIntArrayRefSlow(size));
 }
 
 Tensor view_as_complex_batching_rule(const Tensor& self) {
@@ -1011,7 +1009,6 @@ Tensor new_empty_symint_batching_rule(
     c10::optional<Layout> layout,
     c10::optional<Device> device,
     c10::optional<bool> pin_memory) {
-  // TODO: properly support this
   return new_empty_batching_rule(self, asIntArrayRefSlow(size), dtype, layout, device, pin_memory);
 }
 
@@ -1112,7 +1109,8 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
   m.impl("tensor_split.sections", tensor_split_sections_batching_rule);
   m.impl("tensor_split.indices", tensor_split_indices_batching_rule);
   m.impl("diagonal", diagonal_batching_rule);
-  m.impl("expand", expand_symint_batching_rule);
+  m.impl("expand", expand_batching_rule);
+  m.impl("expand.SymInt", expand_symint_batching_rule);
   m.impl("expand_as", native::expand_as);  // composite wrt autograd
   m.impl("movedim.intlist", movedim_batching_rule);
   m.impl("movedim.int", static_cast<Tensor(*)(const Tensor&,int64_t,int64_t)>(native::movedim));  // composite wrt autograd
@@ -1140,7 +1138,8 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
   m.impl("unbind.int", unbind_batching_rule);
   m.impl("unfold", unfold_batching_rule);
   m.impl("unsqueeze", unsqueeze_batching_rule);
-  m.impl("view", view_symint_batching_rule);
+  m.impl("view", view_batching_rule);
+  m.impl("view.SymInt", view_symint_batching_rule);
   m.impl("view_as", native::view_as);  // composite wrt autograd
 
   // clamp operations
@@ -1278,7 +1277,8 @@ TORCH_LIBRARY_IMPL(aten, Batched, m) {
   m.impl("diagonal_backward", diagonal_backward_batching_rule);
 
   // Tensor.new_* operators
-  m.impl("new_empty", new_empty_symint_batching_rule);
+  m.impl("new_empty", new_empty_batching_rule);
+  m.impl("new_empty.SymInt", new_empty_symint_batching_rule);
   m.impl("new_empty_strided", new_empty_strided_batching_rule);
   m.impl("new_zeros", new_zeros_batching_rule);
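For reference, a minimal sketch of the registration pattern this diff moves to, shown on a hypothetical custom op (the myops namespace and the expand_copy schemas are illustrative, not part of the commit; it also assumes c10::asIntArrayRefSlow, the helper the diff's kernels call, converts a fully concrete SymIntArrayRef into an IntArrayRef):

// Sketch only: mirrors the split the diff applies to expand/view/new_empty.
// "myops::expand_copy" is a made-up op for illustration.
#include <ATen/ATen.h>
#include <torch/library.h>

namespace {

// The "real" kernel, taking concrete integer sizes.
at::Tensor expand_copy(const at::Tensor& self, at::IntArrayRef size) {
  return self.expand(size).clone();
}

// The SymInt overload lowers its symbolic sizes to concrete ints and reuses
// the integer kernel, analogous to expand_symint_batching_rule above.
at::Tensor expand_copy_symint(const at::Tensor& self, c10::SymIntArrayRef size) {
  return expand_copy(self, asIntArrayRefSlow(size));
}

} // namespace

TORCH_LIBRARY(myops, m) {
  m.def("expand_copy(Tensor self, int[] size) -> Tensor");
  m.def("expand_copy.SymInt(Tensor self, SymInt[] size) -> Tensor");
}

// As in the Batched registrations above: each overload name gets its own
// kernel, rather than routing the plain name to the SymInt kernel.
TORCH_LIBRARY_IMPL(myops, CPU, m) {
  m.impl("expand_copy", expand_copy);
  m.impl("expand_copy.SymInt", expand_copy_symint);
}

Keeping the plain overload on the integer kernel and giving the .SymInt overload its own entry means callers that never produce symbolic sizes skip the SymInt lowering entirely, which is the same split the diff applies to expand, view, and new_empty.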