Skip to content

Commit bf4607c

Browse files
committed
refactor: clean up broadcasted operator errors
1 parent 1c878f2 commit bf4607c

File tree

1 file changed

+31
-58
lines changed

1 file changed

+31
-58
lines changed

src/OperatorEnumConstruction.jl

Lines changed: 31 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,14 @@ end
110110
return mapping[f]
111111
end
112112

113+
function _unpack_broadcast_function(f)
114+
if f isa Broadcast.BroadcastFunction
115+
return Symbol(f.f), :(Broadcast.BroadcastFunction($(f.f)))
116+
else
117+
return Symbol(f), Symbol(f)
118+
end
119+
end
120+
113121
function empty_all_globals!(; force=true)
114122
if force || islocked(LATEST_LOCK)
115123
lock(LATEST_LOCK) do
@@ -269,46 +277,29 @@ function _extend_operators(operators, skip_user_operators, kws, __module__::Modu
269277
f_inside, f_outside, type_requirements, build_converters, internal
270278
)
271279
unary_ex = _extend_unary_operator(f_inside, f_outside, type_requirements, internal)
280+
#! format: off
272281
return quote
273-
local $type_requirements
274-
local $build_converters
275-
local $binary_exists
276-
local $unary_exists
282+
local $type_requirements, $build_converters, $binary_exists, $unary_exists
277283
lock($LATEST_LOCK)
278284
if isa($operators, $OperatorEnum)
279285
$type_requirements = $(on_type == nothing ? Number : on_type)
280286
$build_converters = $(on_type == nothing)
281-
if !haskey(
282-
$(ALREADY_DEFINED_BINARY_OPERATORS).operator_enum, $type_requirements
283-
)
284-
$(ALREADY_DEFINED_BINARY_OPERATORS).operator_enum[$type_requirements] = Dict{
285-
Function,Bool
286-
}()
287+
if !haskey($(ALREADY_DEFINED_BINARY_OPERATORS).operator_enum, $type_requirements)
288+
$(ALREADY_DEFINED_BINARY_OPERATORS).operator_enum[$type_requirements] = Dict{Function,Bool}()
287289
end
288290
if !haskey($(ALREADY_DEFINED_UNARY_OPERATORS).operator_enum, $type_requirements)
289-
$(ALREADY_DEFINED_UNARY_OPERATORS).operator_enum[$type_requirements] = Dict{
290-
Function,Bool
291-
}()
291+
$(ALREADY_DEFINED_UNARY_OPERATORS).operator_enum[$type_requirements] = Dict{Function,Bool}()
292292
end
293293
$binary_exists = $(ALREADY_DEFINED_BINARY_OPERATORS).operator_enum[$type_requirements]
294294
$unary_exists = $(ALREADY_DEFINED_UNARY_OPERATORS).operator_enum[$type_requirements]
295295
else
296296
$type_requirements = $(on_type == nothing ? Any : on_type)
297297
$build_converters = false
298-
if !haskey(
299-
$(ALREADY_DEFINED_BINARY_OPERATORS).generic_operator_enum,
300-
$type_requirements,
301-
)
302-
$(ALREADY_DEFINED_BINARY_OPERATORS).generic_operator_enum[$type_requirements] = Dict{
303-
Function,Bool
304-
}()
298+
if !haskey($(ALREADY_DEFINED_BINARY_OPERATORS).generic_operator_enum, $type_requirements)
299+
$(ALREADY_DEFINED_BINARY_OPERATORS).generic_operator_enum[$type_requirements] = Dict{Function,Bool}()
305300
end
306-
if !haskey(
307-
$(ALREADY_DEFINED_UNARY_OPERATORS).generic_operator_enum, $type_requirements
308-
)
309-
$(ALREADY_DEFINED_UNARY_OPERATORS).generic_operator_enum[$type_requirements] = Dict{
310-
Function,Bool
311-
}()
301+
if !haskey($(ALREADY_DEFINED_UNARY_OPERATORS).generic_operator_enum, $type_requirements)
302+
$(ALREADY_DEFINED_UNARY_OPERATORS).generic_operator_enum[$type_requirements] = Dict{Function,Bool}()
312303
end
313304
$binary_exists = $(ALREADY_DEFINED_BINARY_OPERATORS).generic_operator_enum[$type_requirements]
314305
$unary_exists = $(ALREADY_DEFINED_UNARY_OPERATORS).generic_operator_enum[$type_requirements]
@@ -319,13 +310,7 @@ function _extend_operators(operators, skip_user_operators, kws, __module__::Modu
319310
empty!($(LATEST_UNARY_OPERATOR_MAPPING))
320311
end
321312
for (op, func) in enumerate($(operators).binops)
322-
local $f_outside =
323-
typeof(func) <: Broadcast.BroadcastFunction ? Symbol(func.f) : Symbol(func)
324-
local $f_inside = if typeof(func) <: Broadcast.BroadcastFunction
325-
:(Broadcast.BroadcastFunction($(func.f)))
326-
else
327-
Symbol(func)
328-
end
313+
local ($f_outside, $f_inside) = $(_unpack_broadcast_function)(func)
329314
local $skip = false
330315
if isdefined(Base, $f_outside)
331316
$f_outside = :(Base.$($f_outside))
@@ -343,13 +328,7 @@ function _extend_operators(operators, skip_user_operators, kws, __module__::Modu
343328
end
344329
end
345330
for (op, func) in enumerate($(operators).unaops)
346-
local $f_outside =
347-
typeof(func) <: Broadcast.BroadcastFunction ? Symbol(func.f) : Symbol(func)
348-
local $f_inside = if typeof(func) <: Broadcast.BroadcastFunction
349-
:(Broadcast.BroadcastFunction($(func.f)))
350-
else
351-
Symbol(func)
352-
end
331+
local ($f_outside, $f_inside) = $(_unpack_broadcast_function)(func)
353332
local $skip = false
354333
if isdefined(Base, $f_outside)
355334
$f_outside = :(Base.$($f_outside))
@@ -368,6 +347,7 @@ function _extend_operators(operators, skip_user_operators, kws, __module__::Modu
368347
end
369348
unlock($LATEST_LOCK)
370349
end
350+
#! format: on
371351
end
372352

373353
"""
@@ -388,24 +368,6 @@ macro extend_operators(operators, kws...)
388368
if !isa($(operators), $expected_type)
389369
error("You must pass an operator enum to `@extend_operators`.")
390370
end
391-
for bo in $(operators).unaops
392-
!(typeof(bo) <: Broadcast.BroadcastFunction) && continue
393-
!(bo.f in $(operators).unaops) && continue
394-
error(
395-
"Usage of both broadcasted and unboradcasted operator " *
396-
string(bo.f) *
397-
" is ambiguous",
398-
)
399-
end
400-
for bo in $(operators).binops
401-
!(typeof(bo) <: Broadcast.BroadcastFunction) && continue
402-
!(bo.f in $(operators).binops) && continue
403-
error(
404-
"Usage of both broadcasted and unboradcasted operator " *
405-
string(bo.f) *
406-
" is ambiguous",
407-
)
408-
end
409371
$ex
410372
end,
411373
)
@@ -475,6 +437,17 @@ redefine operators for `AbstractExpressionNode` types, as well as `show`, `print
475437
end
476438
end
477439

440+
for ops in (binary_operators, unary_operators), op in ops
441+
if op isa Broadcast.BroadcastFunction &&
442+
(op.f in binary_operators || op.f in unary_operators)
443+
throw(
444+
ArgumentError(
445+
"Usage of both broadcasted and unbroadcasted operator $(op.f) is ambiguous",
446+
),
447+
)
448+
end
449+
end
450+
478451
operators = OperatorEnum(Tuple(binary_operators), Tuple(unary_operators))
479452

480453
if define_helper_functions

0 commit comments

Comments
 (0)