@MikeInnes suggested in FluxML/Flux.jl#31 that we change the dot-call lowering in such a way as to allow it to be disabled for custom types (which may reside on a GPU or something and support only a small set of functions).
For example, currently `x .+ y .* x.^z` is lowered to `broadcast((x,y,z) -> x+y*x^z, x,y,z)`. This could instead be lowered to:
```julia
if Base.isfusing(x, y, z)
    broadcast((x, y, z) -> x + y * x^z, x, y, z)
else
    broadcast(+, x, broadcast(*, y, broadcast(^, x, z)))
end
```
with `Base.isfusing(...) = true` being the default. This would also address #22053 (cc @malmaud).
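
As a quick sanity check (nothing here is proposal-specific), the fused and unfused forms agree on ordinary `Array`s and differ only in the temporaries the unfused chain allocates:

```julia
# Compare the two lowerings of x .+ y .* x.^z on plain Arrays.
x, y, z = rand(3), rand(3), rand(3)

fused   = broadcast((x, y, z) -> x + y * x^z, x, y, z)
unfused = broadcast(+, x, broadcast(*, y, broadcast(^, x, z)))

@assert fused ≈ unfused   # same values, different allocation behavior
```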
Pro:

- Makes life easier in the short run for containers based on external libraries that only support a few efficient operations. They could overload `isfusing` and `broadcast(::typeof(foo), ...)` for a small number of supported `foo` (see the sketch after this list).
- Makes life easier in specialized applications where it may not be possible to define `broadcast` for general functions, e.g. in Convex.jl, where it needs to guarantee convexity.
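
A minimal sketch of what such a package might write, assuming the proposed `Base.isfusing` hook existed (the `LibArray` type and its forwarding are made up for illustration):

```julia
# Hypothetical wrapper around an external library's buffer; the plain
# Array field stands in for, e.g., device memory.
struct LibArray{T,N}
    data::Array{T,N}
end

# Opt out of fusion whenever a LibArray is involved (only the
# first-operand case is shown, to keep the sketch short).
Base.isfusing(::LibArray, rest...) = false

# Provide broadcast only for the handful of operations the library
# supports; here they just forward to ordinary broadcast on the data.
for f in (:+, :*, :^)
    @eval Base.broadcast(::typeof($f), A::LibArray, B::LibArray) =
        LibArray(broadcast($f, A.data, B.data))
end
```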
Con:

- You can no longer look at an expression like `x .+ y .* x` and know that it fuses into a single loop with no temporaries.
- There is no middle ground. A single non-fusable operand will "spoil" fusion for an entire set of nested dot calls. (To fuse a subset of an expression, you'd have to explicitly assign it to a temporary array.) (We could lower to a nested set of `isfusing` calls, of course, but then you'd get an exponential explosion in lowered code size; see the sketch after this list.)
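
A rough sketch of what that nested lowering might look like for the running example (not actual lowering output, and only a few of the branches are shown):

```julia
# One branch per subset of operations that might fuse; with k dotted
# calls there are on the order of 2^k such subsets.
if Base.isfusing(x, y, z)
    broadcast((x, y, z) -> x + y * x^z, x, y, z)                      # fuse all
elseif Base.isfusing(y, x, z)
    broadcast(+, x, broadcast((y, x, z) -> y * x^z, y, x, z))         # fuse * and ^
elseif Base.isfusing(x, z)
    broadcast(+, x, broadcast(*, y, broadcast((x, z) -> x^z, x, z)))  # fuse ^ only
else
    broadcast(+, x, broadcast(*, y, broadcast(^, x, z)))              # fuse nothing
end
```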
In the long run, I really think that packages like TensorFlow.jl should exploit the underlying library's ability to define custom operations as callback functions to implement `broadcast(f::Function, ...)` for arbitrary `f`, at which point they can define `isfusing(...) = true` and get all the benefits of fusion.