Skip to content

Commit 0c8aab9

Browse files
committed
refactor: clean up broadcasted operators
1 parent 1c878f2 commit 0c8aab9

File tree

1 file changed

+35
-58
lines changed

1 file changed

+35
-58
lines changed

src/OperatorEnumConstruction.jl

Lines changed: 35 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,28 @@ end
110110
return mapping[f]
111111
end
112112

113+
function _unpack_broadcast_function(f)
114+
if f isa Broadcast.BroadcastFunction
115+
return Symbol(f.f), :(Broadcast.BroadcastFunction($(f.f)))
116+
else
117+
return Symbol(f), Symbol(f)
118+
end
119+
end
120+
121+
function _validate_no_ambiguous_broadcasts(operators::AbstractOperatorEnum)
122+
for ops in (operators.binops, operators.unaops), op in ops
123+
if op isa Broadcast.BroadcastFunction &&
124+
(op.f in operators.binops || op.f in operators.unaops)
125+
throw(
126+
ArgumentError(
127+
"Usage of both broadcasted and unbroadcasted operator `$(op.f)` is ambiguous",
128+
),
129+
)
130+
end
131+
end
132+
return nothing
133+
end
134+
113135
function empty_all_globals!(; force=true)
114136
if force || islocked(LATEST_LOCK)
115137
lock(LATEST_LOCK) do
@@ -269,46 +291,30 @@ function _extend_operators(operators, skip_user_operators, kws, __module__::Modu
269291
f_inside, f_outside, type_requirements, build_converters, internal
270292
)
271293
unary_ex = _extend_unary_operator(f_inside, f_outside, type_requirements, internal)
294+
#! format: off
272295
return quote
273-
local $type_requirements
274-
local $build_converters
275-
local $binary_exists
276-
local $unary_exists
296+
local $type_requirements, $build_converters, $binary_exists, $unary_exists
297+
$(_validate_no_ambiguous_broadcasts)($operators)
277298
lock($LATEST_LOCK)
278299
if isa($operators, $OperatorEnum)
279300
$type_requirements = $(on_type == nothing ? Number : on_type)
280301
$build_converters = $(on_type == nothing)
281-
if !haskey(
282-
$(ALREADY_DEFINED_BINARY_OPERATORS).operator_enum, $type_requirements
283-
)
284-
$(ALREADY_DEFINED_BINARY_OPERATORS).operator_enum[$type_requirements] = Dict{
285-
Function,Bool
286-
}()
302+
if !haskey($(ALREADY_DEFINED_BINARY_OPERATORS).operator_enum, $type_requirements)
303+
$(ALREADY_DEFINED_BINARY_OPERATORS).operator_enum[$type_requirements] = Dict{Function,Bool}()
287304
end
288305
if !haskey($(ALREADY_DEFINED_UNARY_OPERATORS).operator_enum, $type_requirements)
289-
$(ALREADY_DEFINED_UNARY_OPERATORS).operator_enum[$type_requirements] = Dict{
290-
Function,Bool
291-
}()
306+
$(ALREADY_DEFINED_UNARY_OPERATORS).operator_enum[$type_requirements] = Dict{Function,Bool}()
292307
end
293308
$binary_exists = $(ALREADY_DEFINED_BINARY_OPERATORS).operator_enum[$type_requirements]
294309
$unary_exists = $(ALREADY_DEFINED_UNARY_OPERATORS).operator_enum[$type_requirements]
295310
else
296311
$type_requirements = $(on_type == nothing ? Any : on_type)
297312
$build_converters = false
298-
if !haskey(
299-
$(ALREADY_DEFINED_BINARY_OPERATORS).generic_operator_enum,
300-
$type_requirements,
301-
)
302-
$(ALREADY_DEFINED_BINARY_OPERATORS).generic_operator_enum[$type_requirements] = Dict{
303-
Function,Bool
304-
}()
313+
if !haskey($(ALREADY_DEFINED_BINARY_OPERATORS).generic_operator_enum, $type_requirements)
314+
$(ALREADY_DEFINED_BINARY_OPERATORS).generic_operator_enum[$type_requirements] = Dict{Function,Bool}()
305315
end
306-
if !haskey(
307-
$(ALREADY_DEFINED_UNARY_OPERATORS).generic_operator_enum, $type_requirements
308-
)
309-
$(ALREADY_DEFINED_UNARY_OPERATORS).generic_operator_enum[$type_requirements] = Dict{
310-
Function,Bool
311-
}()
316+
if !haskey($(ALREADY_DEFINED_UNARY_OPERATORS).generic_operator_enum, $type_requirements)
317+
$(ALREADY_DEFINED_UNARY_OPERATORS).generic_operator_enum[$type_requirements] = Dict{Function,Bool}()
312318
end
313319
$binary_exists = $(ALREADY_DEFINED_BINARY_OPERATORS).generic_operator_enum[$type_requirements]
314320
$unary_exists = $(ALREADY_DEFINED_UNARY_OPERATORS).generic_operator_enum[$type_requirements]
@@ -319,13 +325,7 @@ function _extend_operators(operators, skip_user_operators, kws, __module__::Modu
319325
empty!($(LATEST_UNARY_OPERATOR_MAPPING))
320326
end
321327
for (op, func) in enumerate($(operators).binops)
322-
local $f_outside =
323-
typeof(func) <: Broadcast.BroadcastFunction ? Symbol(func.f) : Symbol(func)
324-
local $f_inside = if typeof(func) <: Broadcast.BroadcastFunction
325-
:(Broadcast.BroadcastFunction($(func.f)))
326-
else
327-
Symbol(func)
328-
end
328+
local ($f_outside, $f_inside) = $(_unpack_broadcast_function)(func)
329329
local $skip = false
330330
if isdefined(Base, $f_outside)
331331
$f_outside = :(Base.$($f_outside))
@@ -343,13 +343,7 @@ function _extend_operators(operators, skip_user_operators, kws, __module__::Modu
343343
end
344344
end
345345
for (op, func) in enumerate($(operators).unaops)
346-
local $f_outside =
347-
typeof(func) <: Broadcast.BroadcastFunction ? Symbol(func.f) : Symbol(func)
348-
local $f_inside = if typeof(func) <: Broadcast.BroadcastFunction
349-
:(Broadcast.BroadcastFunction($(func.f)))
350-
else
351-
Symbol(func)
352-
end
346+
local ($f_outside, $f_inside) = $(_unpack_broadcast_function)(func)
353347
local $skip = false
354348
if isdefined(Base, $f_outside)
355349
$f_outside = :(Base.$($f_outside))
@@ -368,6 +362,7 @@ function _extend_operators(operators, skip_user_operators, kws, __module__::Modu
368362
end
369363
unlock($LATEST_LOCK)
370364
end
365+
#! format: on
371366
end
372367

373368
"""
@@ -388,24 +383,6 @@ macro extend_operators(operators, kws...)
388383
if !isa($(operators), $expected_type)
389384
error("You must pass an operator enum to `@extend_operators`.")
390385
end
391-
for bo in $(operators).unaops
392-
!(typeof(bo) <: Broadcast.BroadcastFunction) && continue
393-
!(bo.f in $(operators).unaops) && continue
394-
error(
395-
"Usage of both broadcasted and unboradcasted operator " *
396-
string(bo.f) *
397-
" is ambiguous",
398-
)
399-
end
400-
for bo in $(operators).binops
401-
!(typeof(bo) <: Broadcast.BroadcastFunction) && continue
402-
!(bo.f in $(operators).binops) && continue
403-
error(
404-
"Usage of both broadcasted and unboradcasted operator " *
405-
string(bo.f) *
406-
" is ambiguous",
407-
)
408-
end
409386
$ex
410387
end,
411388
)

0 commit comments

Comments
 (0)