@@ -446,9 +446,6 @@ function Sharding.disassemble_into_single_device_arrays(
446
446
single_device_shards = XLA. IFRT. disassemble_into_single_device_arrays (x. data, true )
447
447
448
448
padded_size = size (x) .+ get_padding (x)
449
- @show padded_size
450
- @show size (x)
451
- @show get_padding (x)
452
449
453
450
if x. sharding. sharding isa Sharding. HloSharding
454
451
(; hlo_sharding) = x. sharding. sharding
@@ -462,13 +459,44 @@ function Sharding.disassemble_into_single_device_arrays(
462
459
padded_size,
463
460
x. sharding. mesh. logical_device_ids,
464
461
)
465
- return [
462
+
463
+ mapping = [
466
464
slice => ConcreteIFRTArray {T,N} (
467
- XLA. IFRT. AsyncArray (shard, nothing ), length .(slice), Sharding. NoShardInfo ()
465
+ XLA. IFRT. AsyncArray (shard, nothing ),
466
+ map (length, slice),
467
+ Sharding. NoShardInfo (),
468
468
) for
469
469
(slice, shard, device) in zip (array_slices, single_device_shards, all_devices) if
470
470
XLA. is_addressable (device)
471
471
]
472
+
473
+ has_padding (x) || return mapping
474
+
475
+ mapping_unpadded = Vector {eltype(mapping)} (undef, length (mapping))
476
+ for (i, (slice, shard)) in enumerate (mapping)
477
+ chop_ends = map (enumerate (slice)) do (i, idx_range)
478
+ last (idx_range) > size (x, i) && return last (idx_range) - size (x, i)
479
+ return 0
480
+ end
481
+
482
+ if all (iszero, chop_ends)
483
+ mapping_unpadded[i] = mapping[i]
484
+ else
485
+ new_slice = map (zip (slice, chop_ends)) do (idx_range, chop)
486
+ chop == 0 && return idx_range
487
+ return first (idx_range): (last (idx_range) - chop)
488
+ end
489
+ if ! any (iszero ∘ length, new_slice)
490
+ mapping_unpadded[i] =
491
+ Tuple (new_slice) => shard[map (Base. OneTo ∘ length, new_slice)... ]
492
+ end
493
+ end
494
+ end
495
+
496
+ return [
497
+ mapping_unpadded[i] for
498
+ i in 1 : length (mapping_unpadded) if isassigned (mapping_unpadded, i)
499
+ ]
472
500
end
473
501
474
502
# # ConcreteRNG
0 commit comments