Skip to content

Commit cd2e260

Browse files
authored
treat .= as syntactic sugar for broadcast! (#17510)
* treat .= as syntactic sugar for broadcast! * tests * optimized .= assignment of scalars and vector copies * .= documentation * fix show of .= ops * .-= tests * NEWS for .=
1 parent bc034fc commit cd2e260

File tree

9 files changed

+153
-88
lines changed

9 files changed

+153
-88
lines changed

NEWS.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,11 @@ New language features
1010
* Generators and comprehensions support filtering using `if` ([#550]) and nested
1111
iteration using multiple `for` keywords ([#4867]).
1212

13-
* Broadcasting syntax: ``f.(args...)`` is equivalent to ``broadcast(f, args...)`` ([#15032]),
13+
* Fused broadcasting syntax: ``f.(args...)`` is equivalent to ``broadcast(f, args...)`` ([#15032]),
1414
and nested `f.(g.(args...))` calls are fused into a single `broadcast` loop ([#17300]).
15+
Similarly, the syntax `x .= ...` is equivalent to a `broadcast!(identity, x, ...)`
16+
call and fuses with nested "dot" calls; also, `x .+= y` and similar is now
17+
equivalent to `x .= x .+ y`, rather than `=` ([#17510]).
1518

1619
* Macro expander functions are now generic, so macros can have multiple definitions
1720
(e.g. for different numbers of arguments, or optional arguments) ([#8846], [#9627]).
@@ -357,3 +360,4 @@ Deprecated or removed
357360
[#17393]: https://github.com/JuliaLang/julia/issues/17393
358361
[#17402]: https://github.com/JuliaLang/julia/issues/17402
359362
[#17404]: https://github.com/JuliaLang/julia/issues/17404
363+
[#17510]: https://github.com/JuliaLang/julia/issues/17510

base/broadcast.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,15 @@ export broadcast_getindex, broadcast_setindex!
1515
broadcast(f) = f()
1616
broadcast(f, x::Number...) = f(x...)
1717

18+
# special cases for "X .= ..." (broadcast!) assignments
19+
broadcast!(::typeof(identity), X::AbstractArray, x::Number) = fill!(X, x)
20+
broadcast!(f, X::AbstractArray) = fill!(X, f())
21+
broadcast!(f, X::AbstractArray, x::Number...) = fill!(X, f(x...))
22+
function broadcast!{T,S,N}(::typeof(identity), x::AbstractArray{T,N}, y::AbstractArray{S,N})
23+
check_broadcast_shape(size(x), size(y))
24+
copy!(x, y)
25+
end
26+
1827
## Calculate the broadcast shape of the arguments, or error if incompatible
1928
# array inputs
2029
broadcast_shape() = ()

base/show.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -408,8 +408,10 @@ show_unquoted(io::IO, ex, ::Int,::Int) = show(io, ex)
408408
const indent_width = 4
409409
const quoted_syms = Set{Symbol}([:(:),:(::),:(:=),:(=),:(==),:(!=),:(===),:(!==),:(=>),:(>=),:(<=)])
410410
const uni_ops = Set{Symbol}([:(+), :(-), :(!), :(¬), :(~), :(<:), :(>:), :(), :(), :()])
411-
const expr_infix_wide = Set{Symbol}([:(=), :(+=), :(-=), :(*=), :(/=), :(\=), :(&=),
412-
:(|=), :($=), :(>>>=), :(>>=), :(<<=), :(&&), :(||), :(<:), :(=>), :(÷=)])
411+
const expr_infix_wide = Set{Symbol}([
412+
:(=), :(+=), :(-=), :(*=), :(/=), :(\=), :(^=), :(&=), :(|=), :(÷=), :(%=), :(>>>=), :(>>=), :(<<=),
413+
:(.=), :(.+=), :(.-=), :(.*=), :(./=), :(.\=), :(.^=), :(.&=), :(.|=), :(.÷=), :(.%=), :(.>>>=), :(.>>=), :(.<<=),
414+
:(&&), :(||), :(<:), :(=>), :($=)])
413415
const expr_infix = Set{Symbol}([:(:), :(->), Symbol("::")])
414416
const expr_infix_any = union(expr_infix, expr_infix_wide)
415417
const all_ops = union(quoted_syms, uni_ops, expr_infix_any)

doc/manual/arrays.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -566,7 +566,7 @@ function elementwise:
566566
1.71056 0.847604
567567
1.73659 0.873631
568568

569-
Elementwise operators such as ``.+`` and ``.*`` perform broadcasting if necessary. There is also a :func:`broadcast!` function to specify an explicit destination, and :func:`broadcast_getindex` and :func:`broadcast_setindex!` that broadcast the indices before indexing. Moreover, ``f.(args...)`` is equivalent to ``broadcast(f, args...)``, providing a convenient syntax to broadcast any function (:ref:`man-dot-vectorizing`:.).
569+
Elementwise operators such as ``.+`` and ``.*`` perform broadcasting if necessary. There is also a :func:`broadcast!` function to specify an explicit destination, and :func:`broadcast_getindex` and :func:`broadcast_setindex!` that broadcast the indices before indexing. Moreover, ``f.(args...)`` is equivalent to ``broadcast(f, args...)``, providing a convenient syntax to broadcast any function (:ref:`man-dot-vectorizing`:).
570570

571571
Implementation
572572
--------------

doc/manual/functions.rst

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -652,9 +652,20 @@ the fusion stops as soon as a "non-dot" function is encountered; for example,
652652
in ``sin.(sort(cos.(X)))`` the ``sin`` and ``cos`` loops cannot be merged
653653
because of the intervening ``sort`` function.
654654

655+
Finally, the maximum efficiency is typically achieved when the output
656+
array of a vectorized operation is *pre-allocated*, so that repeated
657+
calls do not allocate new arrays over and over again for the results
658+
(:ref:`man-preallocation`:). A convenient syntax for this is
659+
``X .= ...``, which is equivalent to ``broadcast!(identity, X, ...)``
660+
except that, as above, the ``broadcast!`` loop is fused with any nested
661+
"dot" calls. For example, ``X .= sin.(Y)`` is equivalent to
662+
``broadcast!(sin, X, Y)``, overwriting ``X`` with ``sin.(Y)`` in-place.
663+
655664
(In future versions of Julia, operators like ``.*`` will also be handled with
656665
the same mechanism: they will be equivalent to ``broadcast`` calls and
657-
will be fused with other nested "dot" calls.)
666+
will be fused with other nested "dot" calls. ``x .+= y`` is equivalent
667+
to ``x .= x .+ y`` and will eventually result in a fused in-place assignment.
668+
Similarly for ``.*=`` etcetera.)
658669

659670
Further Reading
660671
---------------

doc/manual/performance-tips.rst

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -944,7 +944,10 @@ above, we could have passed a :class:`SubArray` rather than an :class:`Array`,
944944
had we so desired.
945945

946946
Taken to its extreme, pre-allocation can make your code uglier, so
947-
performance measurements and some judgment may be required.
947+
performance measurements and some judgment may be required. However,
948+
for "vectorized" (element-wise) functions, the convenient syntax
949+
``x .= f.(y)`` can be used for in-place operations with fused loops
950+
and no temporary arrays (:ref:`dot-vectorizing`).
948951

949952

950953
Avoid string interpolation for I/O

src/julia-syntax.scm

Lines changed: 95 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1418,12 +1418,12 @@
14181418
`(call ,(cadr e) ,(expand-forms a) ,(expand-forms b))))))
14191419

14201420
;; convert `a+=b` to `a=a+b`
1421-
(define (expand-update-operator- op lhs rhs declT)
1421+
(define (expand-update-operator- op op= lhs rhs declT)
14221422
(let ((e (remove-argument-side-effects lhs)))
14231423
`(block ,@(cdr e)
14241424
,(if (null? declT)
1425-
`(= ,(car e) (call ,op ,(car e) ,rhs))
1426-
`(= ,(car e) (call ,op (:: ,(car e) ,(car declT)) ,rhs))))))
1425+
`(,op= ,(car e) (call ,op ,(car e) ,rhs))
1426+
`(,op= ,(car e) (call ,op (:: ,(car e) ,(car declT)) ,rhs))))))
14271427

14281428
(define (partially-expand-ref e)
14291429
(let ((a (cadr e))
@@ -1443,31 +1443,32 @@
14431443
,@(append stmts stuff)
14441444
(call getindex ,arr ,@new-idxs))))))
14451445

1446-
(define (expand-update-operator op lhs rhs . declT)
1446+
(define (expand-update-operator op op= lhs rhs . declT)
14471447
(cond ((and (pair? lhs) (eq? (car lhs) 'ref))
14481448
;; expand indexing inside op= first, to remove "end" and ":"
14491449
(let* ((ex (partially-expand-ref lhs))
14501450
(stmts (butlast (cdr ex)))
14511451
(refex (last (cdr ex)))
14521452
(nuref `(ref ,(caddr refex) ,@(cdddr refex))))
14531453
`(block ,@stmts
1454-
,(expand-update-operator- op nuref rhs declT))))
1454+
,(expand-update-operator- op op= nuref rhs declT))))
14551455
((and (pair? lhs) (eq? (car lhs) '|::|))
14561456
;; (+= (:: x T) rhs)
14571457
(let ((e (remove-argument-side-effects (cadr lhs)))
14581458
(T (caddr lhs)))
14591459
`(block ,@(cdr e)
1460-
,(expand-update-operator op (car e) rhs T))))
1460+
,(expand-update-operator op op= (car e) rhs T))))
14611461
(else
1462-
(expand-update-operator- op lhs rhs declT))))
1462+
(expand-update-operator- op op= lhs rhs declT))))
14631463

14641464
(define (lower-update-op e)
14651465
(expand-forms
1466-
(expand-update-operator
1467-
(let ((str (string (car e))))
1468-
(symbol (string.sub str 0 (- (length str) 1))))
1469-
(cadr e)
1470-
(caddr e))))
1466+
(let ((str (string (car e))))
1467+
(expand-update-operator
1468+
(symbol (string.sub str 0 (- (length str) 1)))
1469+
(if (= (string.char str 0) #\.) '.= '=)
1470+
(cadr e)
1471+
(caddr e)))))
14711472

14721473
(define (expand-and e)
14731474
(let ((e (cdr (flatten-ex '&& e))))
@@ -1546,11 +1547,9 @@
15461547
(cadr expr) ;; eta reduce `x->f(x)` => `f`
15471548
`(-> ,argname (block ,@splat ,expr)))))
15481549

1549-
(define (getfield-field? x) ; whether x from (|.| f x) is a getfield call
1550-
(or (eq? (car x) 'quote) (eq? (car x) 'inert) (eq? (car x) '$)))
1551-
1552-
;; fuse nested calls to f.(args...) into a single broadcast call
1553-
(define (expand-fuse-broadcast f args)
1550+
; fuse nested calls to expr == f.(args...) into a single broadcast call,
1551+
; or a broadcast! call if lhs is non-null.
1552+
(define (expand-fuse-broadcast lhs rhs)
15541553
(define (fuse? e) (and (pair? e) (eq? (car e) 'fuse)))
15551554
(define (anyfuse? exprs)
15561555
(if (null? exprs) #f (if (fuse? (car exprs)) #t (anyfuse? (cdr exprs)))))
@@ -1594,72 +1593,83 @@
15941593
oldarg))
15951594
fargs args)))
15961595
(let ,fbody ,@(reverse (fuse-lets fargs args '()))))))
1597-
(define (make-fuse f args) ; check for nested (fuse f args) exprs and combine
1598-
(define (split-kwargs args) ; return (cons keyword-args positional-args) extracted from args
1599-
(define (sk args kwargs pargs)
1600-
(if (null? args)
1601-
(cons kwargs pargs)
1602-
(if (kwarg? (car args))
1603-
(sk (cdr args) (cons (car args) kwargs) pargs)
1604-
(sk (cdr args) kwargs (cons (car args) pargs)))))
1605-
(if (has-parameters? args)
1606-
(sk (reverse (cdr args)) (cdar args) '())
1607-
(sk (reverse args) '() '())))
1608-
(define (dot-to-fuse e) ; convert e == (. f (tuple args)) to (fuse f args)
1609-
(if (and (pair? e) (eq? (car e) '|.|) (not (getfield-field? (caddr e))))
1610-
(make-fuse (cadr e) (cdaddr e))
1611-
e))
1612-
(let* ((kws.args (split-kwargs args))
1613-
(kws (car kws.args))
1614-
(args (cdr kws.args)) ; fusing occurs on positional args only
1615-
(args_ (map dot-to-fuse args)))
1616-
(if (anyfuse? args_)
1617-
`(fuse ,(fuse-funcs (to-lambda f args kws) args_) ,(fuse-args args_))
1618-
`(fuse ,(to-lambda f args kws) ,args_))))
1596+
(define (dot-to-fuse e) ; convert e == (. f (tuple args)) to (fuse f args)
1597+
(define (make-fuse f args) ; check for nested (fuse f args) exprs and combine
1598+
(define (split-kwargs args) ; return (cons keyword-args positional-args) extracted from args
1599+
(define (sk args kwargs pargs)
1600+
(if (null? args)
1601+
(cons kwargs pargs)
1602+
(if (kwarg? (car args))
1603+
(sk (cdr args) (cons (car args) kwargs) pargs)
1604+
(sk (cdr args) kwargs (cons (car args) pargs)))))
1605+
(if (has-parameters? args)
1606+
(sk (reverse (cdr args)) (cdar args) '())
1607+
(sk (reverse args) '() '())))
1608+
(let* ((kws.args (split-kwargs args))
1609+
(kws (car kws.args))
1610+
(args (cdr kws.args)) ; fusing occurs on positional args only
1611+
(args_ (map dot-to-fuse args)))
1612+
(if (anyfuse? args_)
1613+
`(fuse ,(fuse-funcs (to-lambda f args kws) args_) ,(fuse-args args_))
1614+
`(fuse ,(to-lambda f args kws) ,args_))))
1615+
(if (and (pair? e) (eq? (car e) '|.|))
1616+
(let ((f (cadr e)) (x (caddr e)))
1617+
(if (or (eq? (car x) 'quote) (eq? (car x) 'inert) (eq? (car x) '$))
1618+
`(call (core getfield) ,f ,x)
1619+
(make-fuse f (cdr x))))
1620+
e))
16191621
; given e == (fuse lambda args), compress the argument list by removing (pure)
16201622
; duplicates in args, inlining literals, and moving any varargs to the end:
16211623
(define (compress-fuse e)
16221624
(define (findfarg arg args fargs) ; for arg in args, return corresponding farg
16231625
(if (eq? arg (car args))
16241626
(car fargs)
16251627
(findfarg arg (cdr args) (cdr fargs))))
1626-
(let ((f (cadr e))
1627-
(args (caddr e)))
1628-
(define (cf old-fargs old-args new-fargs new-args renames varfarg vararg)
1629-
(if (null? old-args)
1630-
(let ((nfargs (if (null? varfarg) new-fargs (cons varfarg new-fargs)))
1631-
(nargs (if (null? vararg) new-args (cons vararg new-args))))
1632-
`(fuse (-> (tuple ,@(reverse nfargs)) ,(replace-vars (caddr f) renames))
1633-
,(reverse nargs)))
1634-
(let ((farg (car old-fargs)) (arg (car old-args)))
1635-
(cond
1636-
((and (vararg? farg) (vararg? arg)) ; arg... must be the last argument
1637-
(if (null? varfarg)
1638-
(cf (cdr old-fargs) (cdr old-args)
1639-
new-fargs new-args renames farg arg)
1640-
(if (eq? (cadr vararg) (cadr arg))
1628+
(if (fuse? e)
1629+
(let ((f (cadr e))
1630+
(args (caddr e)))
1631+
(define (cf old-fargs old-args new-fargs new-args renames varfarg vararg)
1632+
(if (null? old-args)
1633+
(let ((nfargs (if (null? varfarg) new-fargs (cons varfarg new-fargs)))
1634+
(nargs (if (null? vararg) new-args (cons vararg new-args))))
1635+
`(fuse (-> (tuple ,@(reverse nfargs)) ,(replace-vars (caddr f) renames))
1636+
,(reverse nargs)))
1637+
(let ((farg (car old-fargs)) (arg (car old-args)))
1638+
(cond
1639+
((and (vararg? farg) (vararg? arg)) ; arg... must be the last argument
1640+
(if (null? varfarg)
16411641
(cf (cdr old-fargs) (cdr old-args)
1642-
new-fargs new-args (cons (cons (cadr farg) (cadr varfarg)) renames)
1643-
varfarg vararg)
1644-
(error "multiple splatted args cannot be fused into a single broadcast"))))
1645-
((number? arg) ; inline numeric literals
1646-
(cf (cdr old-fargs) (cdr old-args)
1647-
new-fargs new-args
1648-
(cons (cons farg arg) renames)
1649-
varfarg vararg))
1650-
((and (symbol? arg) (memq arg new-args)) ; combine duplicate args
1651-
; (note: calling memq for every arg is O(length(args)^2) ...
1652-
; ... would be better to replace with a hash table if args is long)
1653-
(cf (cdr old-fargs) (cdr old-args)
1654-
new-fargs new-args
1655-
(cons (cons farg (findfarg arg new-args new-fargs)) renames)
1656-
varfarg vararg))
1657-
(else
1658-
(cf (cdr old-fargs) (cdr old-args)
1659-
(cons farg new-fargs) (cons arg new-args) renames varfarg vararg))))))
1660-
(cf (cdadr f) args '() '() '() '() '())))
1661-
(let ((e (compress-fuse (make-fuse f args)))) ; an expression '(fuse func args)
1662-
(expand-forms `(call broadcast ,(from-lambda (cadr e)) ,@(caddr e)))))
1642+
new-fargs new-args renames farg arg)
1643+
(if (eq? (cadr vararg) (cadr arg))
1644+
(cf (cdr old-fargs) (cdr old-args)
1645+
new-fargs new-args (cons (cons (cadr farg) (cadr varfarg)) renames)
1646+
varfarg vararg)
1647+
(error "multiple splatted args cannot be fused into a single broadcast"))))
1648+
((number? arg) ; inline numeric literals
1649+
(cf (cdr old-fargs) (cdr old-args)
1650+
new-fargs new-args
1651+
(cons (cons farg arg) renames)
1652+
varfarg vararg))
1653+
((and (symbol? arg) (memq arg new-args)) ; combine duplicate args
1654+
; (note: calling memq for every arg is O(length(args)^2) ...
1655+
; ... would be better to replace with a hash table if args is long)
1656+
(cf (cdr old-fargs) (cdr old-args)
1657+
new-fargs new-args
1658+
(cons (cons farg (findfarg arg new-args new-fargs)) renames)
1659+
varfarg vararg))
1660+
(else
1661+
(cf (cdr old-fargs) (cdr old-args)
1662+
(cons farg new-fargs) (cons arg new-args) renames varfarg vararg))))))
1663+
(cf (cdadr f) args '() '() '() '() '()))
1664+
e)) ; (not (fuse? e))
1665+
(let ((e (compress-fuse (dot-to-fuse rhs)))) ; an expression '(fuse func args) if expr is a dot call
1666+
(if (fuse? e)
1667+
(if (null? lhs)
1668+
(expand-forms `(call broadcast ,(from-lambda (cadr e)) ,@(caddr e)))
1669+
(expand-forms `(call broadcast! ,(from-lambda (cadr e)) ,lhs ,@(caddr e))))
1670+
(if (null? lhs)
1671+
(expand-forms e)
1672+
(expand-forms `(call broadcast! identity ,lhs ,e))))))
16631673

16641674
;; table mapping expression head to a function expanding that form
16651675
(define expand-table
@@ -1697,13 +1707,11 @@
16971707

16981708
'|.|
16991709
(lambda (e) ; e = (|.| f x)
1700-
(let ((f (cadr e))
1701-
(x (caddr e)))
1702-
(if (getfield-field? x)
1703-
`(call (core getfield) ,(expand-forms f) ,(expand-forms x))
1704-
; otherwise, came from f.(args...) --> broadcast(f, args...),
1705-
; where we want to fuse with any nested broadcast calls.
1706-
(expand-fuse-broadcast f (cdr x)))))
1710+
(expand-fuse-broadcast '() e))
1711+
1712+
'.=
1713+
(lambda (e)
1714+
(expand-fuse-broadcast (cadr e) (caddr e)))
17071715

17081716
'|<:| syntactic-op-to-call
17091717
'|>:| syntactic-op-to-call
@@ -2008,11 +2016,16 @@
20082016
'%= lower-update-op
20092017
'.%= lower-update-op
20102018
'|\|=| lower-update-op
2019+
'|.\|=| lower-update-op
20112020
'&= lower-update-op
2021+
'.&= lower-update-op
20122022
'$= lower-update-op
20132023
'<<= lower-update-op
2024+
'.<<= lower-update-op
20142025
'>>= lower-update-op
2026+
'.>>= lower-update-op
20152027
'>>>= lower-update-op
2028+
'.>>>= lower-update-op
20162029

20172030
':
20182031
(lambda (e)

test/broadcast.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,25 @@ let x = [1:4;]
248248
@test sin.(f17300kw.(x, y=1)) == sin.(f17300kw.(x; y=1)) == sin.(x .+ 1)
249249
end
250250

251+
# PR #17510: Fused in-place assignment
252+
let x = [1:4;], y = x
253+
y .= 2:5
254+
@test y === x == [2:5;]
255+
y .= factorial.(x)
256+
@test y === x == [2,6,24,120]
257+
y .= 7
258+
@test y === x == [7,7,7,7]
259+
y .= factorial.(3)
260+
@test y === x == [6,6,6,6]
261+
f17510() = 9
262+
y .= f17510.()
263+
@test y === x == [9,9,9,9]
264+
y .-= 1
265+
@test y === x == [8,8,8,8]
266+
y .-= 1:4
267+
@test y === x == [7,6,5,4]
268+
end
269+
251270
# PR 16988
252271
@test Base.promote_op(+, Bool) === Int
253272
@test isa(broadcast(+, [true]), Array{Int,1})

test/show.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -557,3 +557,7 @@ end
557557
@test repr(:(x for x in y if aa for z in w if bb)) == ":(x for x = y if aa for z = w if bb)"
558558
@test repr(:([x for x = y])) == ":([x for x = y])"
559559
@test repr(:([x for x = y if z])) == ":([x for x = y if z])"
560+
561+
for op in (:(.=), :(.+=), :(.&=))
562+
@test repr(parse("x $op y")) == ":(x $op y)"
563+
end

0 commit comments

Comments
 (0)