Skip to content

Commit 923242e

Browse files
authored
Fix dynamic shape for shared embedding (#1946)
init
1 parent f3ff2e5 commit 923242e

File tree

1 file changed

+2
-10
lines changed

1 file changed

+2
-10
lines changed

torchao/experimental/ops/embedding_xbit/op_embedding_xbit-impl.h

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -265,16 +265,8 @@ Tensor shared_embedding_out_cpu(
265265

266266
int num_out = indices.size(0);
267267

268-
#ifdef USE_ATEN
269-
TORCHAO_CHECK(out.dtype() == torch::kFloat32, "out must be float32");
270-
out.resize_({num_out, k});
271-
#endif // USE_ATEN
272-
273-
#ifdef USE_EXECUTORCH
274-
TORCHAO_CHECK(out.dim() == 2, "out must be 2D");
275-
TORCHAO_CHECK(out.size(0) == num_out, "out shape is incorrect");
276-
TORCHAO_CHECK(out.size(1) == k, "out shape is incorrect");
277-
#endif // USE_EXECUTORCH
268+
// Explicit cast from int64_t to int is required for Executorch
269+
TORCHAO_RESIZE_TENSOR(out, {(int)num_out, (int)k});
278270

279271
const int32_t* index32_ptr = nullptr;
280272
const int64_t* index64_ptr = nullptr;

0 commit comments

Comments (0)