@@ -413,20 +413,28 @@ TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) {
413
413
xla::XlaComputation xla_computation =
414
414
GetValueOrThrow (b.Build (/* remove_dynamic_dimensions=*/ false ));
415
415
416
- std::vector<torch::lazy::BackendDataPtr> parameters_data;
417
- parameters_data.push_back (
416
+ std::vector<XLATensorPtr> tensors{XLATensor::Create (
418
417
torch_xla::runtime::GetComputationClientOrDie ()->CreateDataPlaceholder (
419
- bridge::GetDefaultDevice ()->toString (), std::move (shape)));
418
+ bridge::GetDefaultDevice ()->toString (), std::move (shape)))};
419
+ std::vector<std::vector<int64_t >> denormalized_tile_assignments;
420
+ for (auto tensor : tensors) {
421
+ auto sharding_spec = tensor->sharding_spec ();
422
+ if (sharding_spec) {
423
+ denormalized_tile_assignments.push_back (
424
+ sharding_spec->sharding .GetDenormalizedTileAssignment ());
425
+ }
426
+ }
420
427
421
428
std::vector<torch_xla::runtime::ComputationClient::CompileInstance> instances;
422
- instances.push_back ({std::move (xla_computation),
423
- bridge::GetDefaultDevice ()->toString (),
424
- {bridge::GetDefaultDevice ()->toString ()},
425
- &shape,
426
- /* should_wrap_parameter=*/ false ,
427
- /* is_sharded=*/ true ,
428
- /* allow_spmd_sharding_propagation_to_output=*/ true ,
429
- /* parameters_data=*/ parameters_data});
429
+ instances.push_back (
430
+ {std::move (xla_computation),
431
+ bridge::GetDefaultDevice ()->toString (),
432
+ {bridge::GetDefaultDevice ()->toString ()},
433
+ &shape,
434
+ /* should_wrap_parameter=*/ false ,
435
+ /* is_sharded=*/ true ,
436
+ /* allow_spmd_sharding_propagation_to_output=*/ true ,
437
+ /* denormalized_tile_assignments=*/ denormalized_tile_assignments});
430
438
431
439
std::vector<
432
440
std::shared_ptr<torch_xla::runtime::ComputationClient::Computation>>
@@ -437,9 +445,6 @@ TEST_F(XLAShardingTest, PrepareOutputShardingPropagation) {
437
445
" add" , std::move (computations[0 ]->move_computation ()));
438
446
439
447
// Prepare output sharding propagation, expect a sharded output placeholder.
440
- std::vector<XLATensorPtr> tensors{XLATensor::Create (
441
- torch_xla::runtime::GetComputationClientOrDie ()->CreateDataPlaceholder (
442
- bridge::GetDefaultDevice ()->toString (), std::move (shape)))};
443
448
std::vector<torch::lazy::BackendDataPtr> data_placeholders;
444
449
std::vector<XLATensor::ShardingSpecPtr> sharding_specs;
445
450
ShardingUtil::PrepareOutputShardingPropagation (
0 commit comments