We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent f3ff2e5 commit 923242eCopy full SHA for 923242e
torchao/experimental/ops/embedding_xbit/op_embedding_xbit-impl.h
@@ -265,16 +265,8 @@ Tensor shared_embedding_out_cpu(
265
266
int num_out = indices.size(0);
267
268
-#ifdef USE_ATEN
269
- TORCHAO_CHECK(out.dtype() == torch::kFloat32, "out must be float32");
270
- out.resize_({num_out, k});
271
-#endif // USE_ATEN
272
-
273
-#ifdef USE_EXECUTORCH
274
- TORCHAO_CHECK(out.dim() == 2, "out must be 2D");
275
- TORCHAO_CHECK(out.size(0) == num_out, "out shape is incorrect");
276
- TORCHAO_CHECK(out.size(1) == k, "out shape is incorrect");
277
-#endif // USE_EXECUTORCH
+ // Explicit cast from int64_t to int is required for Executorch
+ TORCHAO_RESIZE_TENSOR(out, {(int)num_out, (int)k});
278
279
const int32_t* index32_ptr = nullptr;
280
const int64_t* index64_ptr = nullptr;
0 commit comments