Skip to content

Commit a7eca00

Browse files
tpappvtjnash
andauthored
Make Broadcast.result_style work on styles with fields. (#50938)
Fixes #50937. --------- Co-authored-by: Jameson Nash <vtjnash@gmail.com>
1 parent 90e3901 commit a7eca00

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

base/broadcast.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,9 @@ Base.Broadcast.DefaultArrayStyle{1}()
441441
function result_style end
442442

443443
result_style(s::BroadcastStyle) = s
444-
result_style(s1::S, s2::S) where S<:BroadcastStyle = S()
444+
function result_style(s1::S, s2::S) where S<:BroadcastStyle
445+
s1 s2 ? s1 : error("inconsistent broadcast styles, custom rule needed")
446+
end
445447
# Test both orders so users typically only have to declare one order
446448
result_style(s1, s2) = result_join(s1, s2, BroadcastStyle(s1, s2), BroadcastStyle(s2, s1))
447449

@@ -457,7 +459,8 @@ result_join(::Any, ::Any, s::BroadcastStyle, ::Unknown) = s
457459
result_join(::AbstractArrayStyle, ::AbstractArrayStyle, ::Unknown, ::Unknown) =
458460
ArrayConflict()
459461
# Fallbacks in case users define `rule` for both argument-orders (not recommended)
460-
result_join(::Any, ::Any, ::S, ::S) where S<:BroadcastStyle = S()
462+
result_join(::Any, ::Any, s1::S, s2::S) where S<:BroadcastStyle = result_style(s1, s2)
463+
461464
@noinline function result_join(::S, ::T, ::U, ::V) where {S,T,U,V}
462465
error("""
463466
conflicting broadcast rules defined

test/broadcast.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1142,6 +1142,19 @@ end
11421142
@test CartesianIndex(1,2) .+ [CartesianIndex(3,4), CartesianIndex(5,6)] == [CartesianIndex(4, 6), CartesianIndex(6, 8)]
11431143
end
11441144

1145+
struct MyBroadcastStyleWithField <: Broadcast.BroadcastStyle
1146+
i::Int
1147+
end
1148+
# asymmetry intended
1149+
Base.BroadcastStyle(a::MyBroadcastStyleWithField, b::MyBroadcastStyleWithField) = a
1150+
1151+
@testset "issue #50937: styles that have fields" begin
1152+
@test Broadcast.result_style(MyBroadcastStyleWithField(1), MyBroadcastStyleWithField(1)) ==
1153+
MyBroadcastStyleWithField(1)
1154+
@test_throws ErrorException Broadcast.result_style(MyBroadcastStyleWithField(1),
1155+
MyBroadcastStyleWithField(2))
1156+
end
1157+
11451158
# test that `Broadcast` definition is defined as total and eligible for concrete evaluation
11461159
import Base.Broadcast: BroadcastStyle, DefaultArrayStyle
11471160
@test Base.infer_effects(BroadcastStyle, (DefaultArrayStyle{1},DefaultArrayStyle{2},)) |>

0 commit comments

Comments
 (0)