
Performance issues with fusing 12+ broadcasts #22255

Closed
@ChrisRackauckas

Description


Updated OP

While working through the issue, @jebej identified that the underlying problem is fusing 12+ broadcasts; see this comment (#22255 (comment)), which contains an MWE.
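For illustration, here is a minimal sketch of that kind of pattern (an assumption about the shape of the MWE, not the code from the linked comment): a single fused broadcast chaining 12 elementwise additions over 13 small arrays, compared against the equivalent explicit loop. The names fused! and looped! are made up for the sketch.

# Sketch only: fuse 12 broadcasted additions over 13 length-2 vectors,
# mirroring the small state vectors used in the original report below.
using BenchmarkTools

k1,k2,k3,k4,k5,k6,k7,k8,k9,k10,k11,k12,k13 = (rand(2) for _ in 1:13)
u = zeros(2)

fused!(u,k1,k2,k3,k4,k5,k6,k7,k8,k9,k10,k11,k12,k13) =
    (u .= k1 .+ k2 .+ k3 .+ k4 .+ k5 .+ k6 .+ k7 .+ k8 .+ k9 .+ k10 .+ k11 .+ k12 .+ k13; u)

function looped!(u,k1,k2,k3,k4,k5,k6,k7,k8,k9,k10,k11,k12,k13)
    @inbounds for i in eachindex(u)
        u[i] = k1[i]+k2[i]+k3[i]+k4[i]+k5[i]+k6[i]+k7[i]+k8[i]+k9[i]+k10[i]+k11[i]+k12[i]+k13[i]
    end
    u
end

@benchmark fused!($u,$k1,$k2,$k3,$k4,$k5,$k6,$k7,$k8,$k9,$k10,$k11,$k12,$k13)
@benchmark looped!($u,$k1,$k2,$k3,$k4,$k5,$k6,$k7,$k8,$k9,$k10,$k11,$k12,$k13)

On the affected Julia version, the fused form allocating on every call while the loop does not would reproduce the kind of slowdown reported below.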

Original OP

In OrdinaryDiffEq.jl, I see a 10x performance regression due to using broadcast. With the testing code:

const ζ = 0.5
const ω₀ = 10.0

using OrdinaryDiffEq, DiffEqBase

const y₀ = Float64[0., sqrt(1-ζ^2)*ω₀]
const A = 1
const ϕ = 0


function f(t::Float64)
    α = sqrt(1-ζ^2)*ω₀
    x = A*exp(-ζ*ω₀*t)*sin(α*t + ϕ)
    p = A*exp(-ζ*ω₀*t)*(-ζ*ω₀*sin(α*t + ϕ) + α*cos(α*t + ϕ))
    return [x,p]
end

function df(t::Float64, y::Vector{Float64}, dy::Vector{Float64})
    dy[1] = y[2]
    dy[2] = -2*ζ*ω₀*y[2] - ω₀^2*y[1]
    return nothing
end

const T = [0.,10.]
using BenchmarkTools
prob = ODEProblem(df,y₀,(T[1],T[2]))
@benchmark init($prob,Tsit5(),dense=false,dt=1/10)
@benchmark solve($prob,Tsit5(),dense=false,dt=1/10)

using ProfileView

@profile for i in 1:1000; solve(prob,Tsit5(),dense=false,dt=1/10); end
ProfileView.view()

I get a 10x regression by changing the inner loop from:

function perform_step!(integrator,cache::Tsit5Cache,f=integrator.f)
  @unpack t,dt,uprev,u,k = integrator
  uidx = eachindex(integrator.uprev)
  @unpack c1,c2,c3,c4,c5,c6,a21,a31,a32,a41,a42,a43,a51,a52,a53,a54,a61,a62,a63,a64,a65,a71,a72,a73,a74,a75,a76,b1,b2,b3,b4,b5,b6,b7 = cache.tab
  @unpack k1,k2,k3,k4,k5,k6,k7,utilde,tmp,atmp = cache
  a = dt*a21
  for i in uidx
    tmp[i] = @muladd uprev[i]+a*k1[i]
  end
  f(@muladd(t+c1*dt),tmp,k2)
  for i in uidx
    tmp[i] = @muladd uprev[i]+dt*(a31*k1[i]+a32*k2[i])
  end
  f(@muladd(t+c2*dt),tmp,k3)
  for i in uidx
    tmp[i] = @muladd uprev[i]+dt*(a41*k1[i]+a42*k2[i]+a43*k3[i])
  end
  f(@muladd(t+c3*dt),tmp,k4)
  for i in uidx
    tmp[i] = @muladd uprev[i]+dt*(a51*k1[i]+a52*k2[i]+a53*k3[i]+a54*k4[i])
  end
  f(@muladd(t+c4*dt),tmp,k5)
  for i in uidx
    tmp[i] = @muladd uprev[i]+dt*(a61*k1[i]+a62*k2[i]+a63*k3[i]+a64*k4[i]+a65*k5[i])
  end
  f(t+dt,tmp,k6)
  for i in uidx
    u[i] = @muladd uprev[i]+dt*(a71*k1[i]+a72*k2[i]+a73*k3[i]+a74*k4[i]+a75*k5[i]+a76*k6[i])
  end
  f(t+dt,u,k7)
  if integrator.opts.adaptive
    for i in uidx
      utilde[i] = @muladd uprev[i] + dt*(b1*k1[i] + b2*k2[i] + b3*k3[i] + b4*k4[i] + b5*k5[i] + b6*k6[i] + b7*k7[i])
      atmp[i] = ((utilde[i]-u[i])./@muladd(integrator.opts.abstol+max(abs(uprev[i]),abs(u[i])).*integrator.opts.reltol))
    end
    integrator.EEst = integrator.opts.internalnorm(atmp)
  end
  @pack integrator = t,dt,u,k
end

to:

function perform_step!(integrator,cache::Tsit5Cache,f=integrator.f)
  @unpack t,dt,uprev,u,k = integrator
  @unpack c1,c2,c3,c4,c5,c6,a21,a31,a32,a41,a42,a43,a51,a52,a53,a54,a61,a62,a63,a64,a65,a71,a72,a73,a74,a75,a76,b1,b2,b3,b4,b5,b6,b7 = cache.tab
  @unpack k1,k2,k3,k4,k5,k6,k7,utilde,tmp,atmp = cache
  a = dt*a21
  tmp .= @muladd uprev+a*k1
  f(@muladd(t+c1*dt),tmp,k2)
  tmp .= @muladd uprev+dt*(a31*k1+a32*k2)
  f(@muladd(t+c2*dt),tmp,k3)
  tmp .= @muladd uprev+dt*(a41*k1+a42*k2+a43*k3)
  f(@muladd(t+c3*dt),tmp,k4)
  tmp .= @muladd uprev+dt*(a51*k1+a52*k2+a53*k3+a54*k4)
  f(@muladd(t+c4*dt),tmp,k5)
  tmp .= @muladd uprev+dt*(a61*k1+a62*k2+a63*k3+a64*k4+a65*k5)
  f(t+dt,tmp,k6)
  u .= @muladd uprev+dt*(a71*k1+a72*k2+a73*k3+a74*k4+a75*k5+a76*k6)
  f(t+dt,u,k7)
  if integrator.opts.adaptive
    utilde .= @muladd uprev + dt*(b1*k1 + b2*k2 + b3*k3 + b4*k4 + b5*k5 + b6*k6 + b7*k7)
    atmp .= ((utilde.-u)./@muladd(integrator.opts.abstol+max.(abs.(uprev),abs.(u)).*integrator.opts.reltol))
    integrator.EEst = integrator.opts.internalnorm(atmp)
  end
  @pack integrator = t,dt,u,k
end

I.e., the only change is replacing the explicit loops with broadcasts. The input array is y₀, which has length 2. For reference, the @muladd macro expands like:

println(macroexpand(:(u .= @muladd uprev+dt*(a71*k1+a72*k2+a73*k3+a74*k4+a75*k5+a76*k6))))
#u .= (muladd).(dt, (muladd).(a71, k1, (muladd).(a72, k2, (muladd).(a73, k3, (muladd).(a74, k4, (muladd).(a75, k5, a76 .* k6))))), uprev)

println(macroexpand(:(atmp .= ((utilde.-u)./@muladd(integrator.opts.abstol+max.(abs.(uprev),abs.(u)).*integrator.opts.reltol)))))
#atmp .= (utilde .- u) ./ (muladd).(max.(abs.(uprev), abs.(u)), integrator.opts.reltol, integrator.opts.abstol)
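To isolate one of these fused statements from the rest of the solver, it can be benchmarked on its own. This is a sketch with placeholder coefficient values; g_fused! is an illustrative name, not part of OrdinaryDiffEq:

# Sketch: the b-stage update fuses uprev, dt, b1..b7 and k1..k7 into one
# broadcast kernel, so the kernel's argument count is large even though the
# arrays only have length 2.
using BenchmarkTools

g_fused!(utilde,uprev,dt,b1,b2,b3,b4,b5,b6,b7,k1,k2,k3,k4,k5,k6,k7) =
    (utilde .= uprev .+ dt .* (b1.*k1 .+ b2.*k2 .+ b3.*k3 .+ b4.*k4 .+
                               b5.*k5 .+ b6.*k6 .+ b7.*k7); nothing)

utilde, uprev = zeros(2), rand(2)
k1,k2,k3,k4,k5,k6,k7 = (rand(2) for _ in 1:7)
b1,b2,b3,b4,b5,b6,b7 = (rand() for _ in 1:7)   # placeholder coefficients
dt = 1/10

@benchmark g_fused!($utilde,$uprev,$dt,$b1,$b2,$b3,$b4,$b5,$b6,$b7,$k1,$k2,$k3,$k4,$k5,$k6,$k7)

Per-call allocations here, with every argument interpolated, would indicate that the cost sits in the fused broadcast kernel itself rather than in the surrounding solver code.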

The profile is here: https://ufile.io/2lu0f

The benchmark results are as follows. Using loops:

BenchmarkTools.Trial:
  memory estimate:  87.43 KiB
  allocs estimate:  3757
  --------------
  minimum time:     281.032 μs (0.00% GC)
  median time:      526.934 μs (0.00% GC)
  mean time:        475.180 μs (2.88% GC)
  maximum time:     5.601 ms (87.13% GC)
  --------------
  samples:          10000
  evals/sample:     1

and using broadcast:

BenchmarkTools.Trial:
  memory estimate:  854.34 KiB
  allocs estimate:  37976
  --------------
  minimum time:     3.246 ms (0.00% GC)
  median time:      6.185 ms (0.00% GC)
  mean time:        5.762 ms (2.40% GC)
  maximum time:     13.919 ms (26.41% GC)
  --------------
  samples:          867
  evals/sample:     1

Am I hitting some broadcasting splatting penalty or something?
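For what it's worth, one way to test that hypothesis (a sketch only, not a proposed change to the package): split the widest fused expression into two narrower broadcasts via an extra buffer and check whether the allocations go away. If they do, the cost scales with the number of arguments fused into one kernel rather than with broadcasting as such. g_split! and tmp2 are illustrative names, using the same placeholder setup as the sketch above.

# Sketch: the same update as g_fused!, split into two narrower fused broadcasts.
using BenchmarkTools

function g_split!(utilde,tmp2,uprev,dt,b1,b2,b3,b4,b5,b6,b7,k1,k2,k3,k4,k5,k6,k7)
    tmp2   .= b1.*k1 .+ b2.*k2 .+ b3.*k3 .+ b4.*k4          # partial sum with fewer fused arguments
    utilde .= uprev .+ dt .* (tmp2 .+ b5.*k5 .+ b6.*k6 .+ b7.*k7)
    return nothing
end

utilde, tmp2, uprev = zeros(2), zeros(2), rand(2)
k1,k2,k3,k4,k5,k6,k7 = (rand(2) for _ in 1:7)
b1,b2,b3,b4,b5,b6,b7 = (rand() for _ in 1:7)   # placeholder coefficients
dt = 1/10

@benchmark g_split!($utilde,$tmp2,$uprev,$dt,$b1,$b2,$b3,$b4,$b5,$b6,$b7,$k1,$k2,$k3,$k4,$k5,$k6,$k7)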

Metadata

    Labels

    broadcast: Applying a function over a collection
    performance: Must go faster
