File tree Expand file tree Collapse file tree 1 file changed +2
-10
lines changed
torchao/experimental/ops/embedding_xbit Expand file tree Collapse file tree 1 file changed +2
-10
lines changed Original file line number Diff line number Diff line change @@ -265,16 +265,8 @@ Tensor shared_embedding_out_cpu(
265265
266266 int num_out = indices.size (0 );
267267
268- #ifdef USE_ATEN
269- TORCHAO_CHECK (out.dtype () == torch::kFloat32 , " out must be float32" );
270- out.resize_ ({num_out, k});
271- #endif // USE_ATEN
272-
273- #ifdef USE_EXECUTORCH
274- TORCHAO_CHECK (out.dim () == 2 , " out must be 2D" );
275- TORCHAO_CHECK (out.size (0 ) == num_out, " out shape is incorrect" );
276- TORCHAO_CHECK (out.size (1 ) == k, " out shape is incorrect" );
277- #endif // USE_EXECUTORCH
268+ // Explicit cast from int64_t to int is required for Executorch
269+ TORCHAO_RESIZE_TENSOR (out, {(int )num_out, (int )k});
278270
279271 const int32_t * index32_ptr = nullptr ;
280272 const int64_t * index64_ptr = nullptr ;
You can’t perform that action at this time.
0 commit comments