Skip to content

Commit 92c4775

Browse files
committed
fix: handle padding
1 parent c158e2a commit 92c4775

File tree

1 file changed

+33
-5
lines changed

1 file changed

+33
-5
lines changed

src/Types.jl

Lines changed: 33 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -446,9 +446,6 @@ function Sharding.disassemble_into_single_device_arrays(
446446
single_device_shards = XLA.IFRT.disassemble_into_single_device_arrays(x.data, true)
447447

448448
padded_size = size(x) .+ get_padding(x)
449-
@show padded_size
450-
@show size(x)
451-
@show get_padding(x)
452449

453450
if x.sharding.sharding isa Sharding.HloSharding
454451
(; hlo_sharding) = x.sharding.sharding
@@ -462,13 +459,44 @@ function Sharding.disassemble_into_single_device_arrays(
462459
padded_size,
463460
x.sharding.mesh.logical_device_ids,
464461
)
465-
return [
462+
463+
mapping = [
466464
slice => ConcreteIFRTArray{T,N}(
467-
XLA.IFRT.AsyncArray(shard, nothing), length.(slice), Sharding.NoShardInfo()
465+
XLA.IFRT.AsyncArray(shard, nothing),
466+
map(length, slice),
467+
Sharding.NoShardInfo(),
468468
) for
469469
(slice, shard, device) in zip(array_slices, single_device_shards, all_devices) if
470470
XLA.is_addressable(device)
471471
]
472+
473+
has_padding(x) || return mapping
474+
475+
mapping_unpadded = Vector{eltype(mapping)}(undef, length(mapping))
476+
for (i, (slice, shard)) in enumerate(mapping)
477+
chop_ends = map(enumerate(slice)) do (i, idx_range)
478+
last(idx_range) > size(x, i) && return last(idx_range) - size(x, i)
479+
return 0
480+
end
481+
482+
if all(iszero, chop_ends)
483+
mapping_unpadded[i] = mapping[i]
484+
else
485+
new_slice = map(zip(slice, chop_ends)) do (idx_range, chop)
486+
chop == 0 && return idx_range
487+
return first(idx_range):(last(idx_range) - chop)
488+
end
489+
if !any(iszero length, new_slice)
490+
mapping_unpadded[i] =
491+
Tuple(new_slice) => shard[map(Base.OneTo length, new_slice)...]
492+
end
493+
end
494+
end
495+
496+
return [
497+
mapping_unpadded[i] for
498+
i in 1:length(mapping_unpadded) if isassigned(mapping_unpadded, i)
499+
]
472500
end
473501

474502
## ConcreteRNG

0 commit comments

Comments
 (0)