Skip to content

Commit b856927

Browse files
Format Julia code (#1477)
Co-authored-by: enzyme-ci-bot[bot] <78882869+enzyme-ci-bot[bot]@users.noreply.github.com>
1 parent 6c29cd7 commit b856927

File tree

4 files changed

+39
-24
lines changed

4 files changed

+39
-24
lines changed

src/ConcreteRArray.jl

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -371,12 +371,15 @@ function Base.setindex!(a::ConcreteIFRTArray, v, args::Vararg{Int,N}) where {N}
371371
return a
372372
end
373373

374-
@inline function Base.similar(::Type{<:ConcretePJRTArray}, ::Type{S}, dims::Dims;
375-
client::Union{Nothing,XLA.PJRT.Client}=nothing,
376-
idx::Union{Int,Nothing}=nothing,
377-
device::Union{Nothing,XLA.PJRT.Device}=nothing,
378-
sharding::Sharding.AbstractSharding=Sharding.NoSharding()
379-
) where {S}
374+
@inline function Base.similar(
375+
::Type{<:ConcretePJRTArray},
376+
::Type{S},
377+
dims::Dims;
378+
client::Union{Nothing,XLA.PJRT.Client}=nothing,
379+
idx::Union{Int,Nothing}=nothing,
380+
device::Union{Nothing,XLA.PJRT.Device}=nothing,
381+
sharding::Sharding.AbstractSharding=Sharding.NoSharding(),
382+
) where {S}
380383
client = client === nothing ? XLA.default_backend() : client
381384

382385
if idx isa Int && device === nothing
@@ -385,7 +388,9 @@ end
385388

386389
sdata, sharding = sharding(client, device, S, dims)
387390

388-
return ConcretePJRTArray{S,length(dims),length(sdata),typeof(sharding)}(sdata, dims, sharding)
391+
return ConcretePJRTArray{S,length(dims),length(sdata),typeof(sharding)}(
392+
sdata, dims, sharding
393+
)
389394
end
390395

391396
function Base.similar(

src/Sharding.jl

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -418,9 +418,7 @@ function (sharding::NamedSharding)(
418418
return data, ShardInfo(sharding, device_to_array_slices)
419419
end
420420

421-
function (sharding::NamedSharding)(
422-
client::XLA.PJRT.Client, _, S::Type, dims::Dims
423-
)
421+
function (sharding::NamedSharding)(client::XLA.PJRT.Client, _, S::Type, dims::Dims)
424422
if !issorted(sharding.mesh.logical_device_ids)
425423
error("PJRT doesn't support non-iota meshes. Use IFRT instead.")
426424
end
@@ -431,7 +429,13 @@ function (sharding::NamedSharding)(
431429

432430
data = ntuple(length(sharding.mesh)) do i
433431
Base.@_inline_meta
434-
Base.similar(XLA.PJRT.AsyncBuffer, S, Dims(length.(device_to_array_slices[i])); client, device=XLA.get_device(client, sharding.mesh.device_ids[i]))
432+
Base.similar(
433+
XLA.PJRT.AsyncBuffer,
434+
S,
435+
Dims(length.(device_to_array_slices[i]));
436+
client,
437+
device=XLA.get_device(client, sharding.mesh.device_ids[i]),
438+
)
435439
end
436440

437441
return data, ShardInfo(sharding, device_to_array_slices)
@@ -774,7 +778,6 @@ function (sharding::Replicated)(client::XLA.PJRT.Client, dev, S::Type, dims::Dim
774778
return (NamedSharding(sharding, length(dims)))(client, dev, S, dims)
775779
end
776780

777-
778781
function sharding_to_array_slices(sharding::Replicated, size_x; kwargs...)
779782
return sharding_to_array_slices(
780783
NamedSharding(sharding, length(size_x)), size_x; kwargs...
@@ -963,13 +966,17 @@ function (sharding::HloSharding)(
963966
return data, ShardInfo(sharding, device_to_array_slices)
964967
end
965968

966-
function (sharding::HloSharding)(
967-
client::XLA.PJRT.Client, ::Nothing, S::Type, dims::Dims
968-
)
969+
function (sharding::HloSharding)(client::XLA.PJRT.Client, ::Nothing, S::Type, dims::Dims)
969970
device_to_array_slices = sharding_to_array_slices(sharding, dims; client)
970971

971-
data = ntuple(length(sharding.mesh)) do i
972-
Base.similar(XLA.PJRT.AsyncBuffer, S, Dims(length.(device_to_array_slices[i])); client, device=XLA.get_device(client, sharding.mesh.device_ids[i]))
972+
data = ntuple(length(sharding.mesh)) do i
973+
Base.similar(
974+
XLA.PJRT.AsyncBuffer,
975+
S,
976+
Dims(length.(device_to_array_slices[i]));
977+
client,
978+
device=XLA.get_device(client, sharding.mesh.device_ids[i]),
979+
)
973980
end
974981

975982
return data, ShardInfo(sharding, device_to_array_slices)

src/xla/PJRT/AsyncBuffer.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,4 @@ end
1818

1919
@inline function Base.similar(::Type{AsyncBuffer}, args...; kwargs...)
2020
return AsyncBuffer(Base.similar(Buffer, args...; kwargs...)::Buffer, nothing)
21-
end
21+
end

src/xla/PJRT/Buffer.jl

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,14 @@ function Base.similar(a::Buffer, dims::Dims)
7373
return Buffer(buffer)
7474
end
7575

76-
@inline function Base.similar(::Type{Buffer}, S::Type, dims::Dims;
77-
client::Union{Nothing,XLA.PJRT.Client}=nothing,
78-
idx::Union{Int,Nothing}=nothing,
79-
device::Union{Nothing,XLA.PJRT.Device}=nothing,
80-
)
76+
@inline function Base.similar(
77+
::Type{Buffer},
78+
S::Type,
79+
dims::Dims;
80+
client::Union{Nothing,XLA.PJRT.Client}=nothing,
81+
idx::Union{Int,Nothing}=nothing,
82+
device::Union{Nothing,XLA.PJRT.Device}=nothing,
83+
)
8184
client = client === nothing ? XLA.default_backend() : client
8285

8386
if device === nothing
@@ -108,7 +111,7 @@ end
108111
end
109112

110113
function Base.similar(a::Buffer, S::Type, dims::Dims)
111-
Base.similar(Buffer, S, dims; client=XLA.client(a), device=XLA.device(a))
114+
return Base.similar(Buffer, S, dims; client=XLA.client(a), device=XLA.device(a))
112115
end
113116

114117
@inline function free_buffer(buffer::Buffer)

0 commit comments

Comments
 (0)