
Performance issues with fusing 12+ broadcasts #22255

Closed

Description

Updated OP

While working through the issue, @jebej identified that the problem is fusing 12 or more broadcasts; see this comment (#22255 (comment)), which contains an MWE.
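Since that MWE isn't reproduced here, below is a minimal sketch of the pattern in question, assuming the slowdown is tied to the number of arguments in a single fused broadcast (the arrays and coefficients are illustrative stand-ins for the k1..k7 stage vectors, not the linked MWE):

using BenchmarkTools

ks = [rand(2) for _ in 1:12]  # twelve small arrays standing in for the stage vectors
out = zeros(2)

# One fused broadcast with twelve array arguments and twelve scalar coefficients.
fused!(o, k) = (o .= 1 .* k[1] .+ 2 .* k[2] .+ 3 .* k[3] .+ 4 .* k[4] .+ 5 .* k[5] .+ 6 .* k[6] .+
                     7 .* k[7] .+ 8 .* k[8] .+ 9 .* k[9] .+ 10 .* k[10] .+ 11 .* k[11] .+ 12 .* k[12])

# The same computation written as an explicit loop, for comparison.
function looped!(o, k)
    for i in eachindex(o)
        acc = 0.0
        for c in 1:12
            acc += c * k[c][i]
        end
        o[i] = acc
    end
    return o
end

@benchmark fused!($out, $ks)
@benchmark looped!($out, $ks)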

Original OP

In OrdinaryDiffEq.jl, I see a 10x performance regression due to using broadcast. With the following test code:

const ζ = 0.5
const ω₀ = 10.0

using OrdinaryDiffEq, DiffEqBase

const y₀ = Float64[0., sqrt(1-ζ^2)*ω₀]
const A = 1
const ϕ = 0


function f(t::Float64)
    α = sqrt(1-ζ^2)*ω₀
    x = A*exp(-ζ*ω₀*t)*sin(α*t + ϕ)
    p = A*exp(-ζ*ω₀*t)*(-ζ*ω₀*sin(α*t + ϕ) + α*cos(α*t + ϕ))
    return [x,p]
end

function df(t::Float64, y::Vector{Float64}, dy::Vector{Float64})
    dy[1] = y[2]
    dy[2] = -2*ζ*ω₀*y[2] - ω₀^2*y[1]
    return nothing
end

const T = [0.,10.]
using BenchmarkTools
prob = ODEProblem(df,y₀,(T[1],T[2]))
@benchmark init($prob,Tsit5(),dense=false,dt=1/10)
@benchmark solve($prob,Tsit5(),dense=false,dt=1/10)

using ProfileView

@profile for i in 1:1000; solve(prob,Tsit5(),dense=false,dt=1/10); end
ProfileView.view()

I get a 10x regression by changing the inner loop from:

function perform_step!(integrator,cache::Tsit5Cache,f=integrator.f)
  @unpack t,dt,uprev,u,k = integrator
  uidx = eachindex(integrator.uprev)
  @unpack c1,c2,c3,c4,c5,c6,a21,a31,a32,a41,a42,a43,a51,a52,a53,a54,a61,a62,a63,a64,a65,a71,a72,a73,a74,a75,a76,b1,b2,b3,b4,b5,b6,b7 = cache.tab
  @unpack k1,k2,k3,k4,k5,k6,k7,utilde,tmp,atmp = cache
  a = dt*a21
  for i in uidx
    tmp[i] = @muladd uprev[i]+a*k1[i]
  end
  f(@muladd(t+c1*dt),tmp,k2)
  for i in uidx
    tmp[i] = @muladd uprev[i]+dt*(a31*k1[i]+a32*k2[i])
  end
  f(@muladd(t+c2*dt),tmp,k3)
  for i in uidx
    tmp[i] = @muladd uprev[i]+dt*(a41*k1[i]+a42*k2[i]+a43*k3[i])
  end
  f(@muladd(t+c3*dt),tmp,k4)
  for i in uidx
    tmp[i] = @muladd uprev[i]+dt*(a51*k1[i]+a52*k2[i]+a53*k3[i]+a54*k4[i])
  end
  f(@muladd(t+c4*dt),tmp,k5)
  for i in uidx
    tmp[i] = @muladd uprev[i]+dt*(a61*k1[i]+a62*k2[i]+a63*k3[i]+a64*k4[i]+a65*k5[i])
  end
  f(t+dt,tmp,k6)
  for i in uidx
    u[i] = @muladd uprev[i]+dt*(a71*k1[i]+a72*k2[i]+a73*k3[i]+a74*k4[i]+a75*k5[i]+a76*k6[i])
  end
  f(t+dt,u,k7)
  if integrator.opts.adaptive
    for i in uidx
      utilde[i] = @muladd uprev[i] + dt*(b1*k1[i] + b2*k2[i] + b3*k3[i] + b4*k4[i] + b5*k5[i] + b6*k6[i] + b7*k7[i])
      atmp[i] = ((utilde[i]-u[i])./@muladd(integrator.opts.abstol+max(abs(uprev[i]),abs(u[i])).*integrator.opts.reltol))
    end
    integrator.EEst = integrator.opts.internalnorm(atmp)
  end
  @pack integrator = t,dt,u,k
end

to:

function perform_step!(integrator,cache::Tsit5Cache,f=integrator.f)
  @unpack t,dt,uprev,u,k = integrator
  @unpack c1,c2,c3,c4,c5,c6,a21,a31,a32,a41,a42,a43,a51,a52,a53,a54,a61,a62,a63,a64,a65,a71,a72,a73,a74,a75,a76,b1,b2,b3,b4,b5,b6,b7 = cache.tab
  @unpack k1,k2,k3,k4,k5,k6,k7,utilde,tmp,atmp = cache
  a = dt*a21
  tmp .= @muladd uprev+a*k1
  f(@muladd(t+c1*dt),tmp,k2)
  tmp .= @muladd uprev+dt*(a31*k1+a32*k2)
  f(@muladd(t+c2*dt),tmp,k3)
  tmp .= @muladd uprev+dt*(a41*k1+a42*k2+a43*k3)
  f(@muladd(t+c3*dt),tmp,k4)
  tmp .= @muladd uprev+dt*(a51*k1+a52*k2+a53*k3+a54*k4)
  f(@muladd(t+c4*dt),tmp,k5)
  tmp .= @muladd uprev+dt*(a61*k1+a62*k2+a63*k3+a64*k4+a65*k5)
  f(t+dt,tmp,k6)
  u .= @muladd uprev+dt*(a71*k1+a72*k2+a73*k3+a74*k4+a75*k5+a76*k6)
  f(t+dt,u,k7)
  if integrator.opts.adaptive
    utilde .= @muladd uprev + dt*(b1*k1 + b2*k2 + b3*k3 + b4*k4 + b5*k5 + b6*k6 + b7*k7)
    atmp .= ((utilde.-u)./@muladd(integrator.opts.abstol+max.(abs.(uprev),abs.(u)).*integrator.opts.reltol))
    integrator.EEst = integrator.opts.internalnorm(atmp)
  end
  @pack integrator = t,dt,u,k
end

I.e. all that's changed is the loops to broadcasts. The input array is y₀, which has length 2. For reference, the @muladd macro expands like this:

println(macroexpand(:(u .= @muladd uprev+dt*(a71*k1+a72*k2+a73*k3+a74*k4+a75*k5+a76*k6))))
#u .= (muladd).(dt, (muladd).(a71, k1, (muladd).(a72, k2, (muladd).(a73, k3, (muladd).(a74, k4, (muladd).(a75, k5, a76 .* k6))))), uprev)

println(macroexpand(:(atmp .= ((utilde.-u)./@muladd(integrator.opts.abstol+max.(abs.(uprev),abs.(u)).*integrator.opts.reltol)))))
#atmp .= (utilde .- u) ./ (muladd).(max.(abs.(uprev), abs.(u)), integrator.opts.reltol, integrator.opts.abstol)

The profile is here: https://ufile.io/2lu0f

The benchmark results using loops are:

BenchmarkTools.Trial:
  memory estimate:  87.43 KiB
  allocs estimate:  3757
  --------------
  minimum time:     281.032 μs (0.00% GC)
  median time:      526.934 μs (0.00% GC)
  mean time:        475.180 μs (2.88% GC)
  maximum time:     5.601 ms (87.13% GC)
  --------------
  samples:          10000
  evals/sample:     1

and using broadcast:

BenchmarkTools.Trial:
  memory estimate:  854.34 KiB
  allocs estimate:  37976
  --------------
  minimum time:     3.246 ms (0.00% GC)
  median time:      6.185 ms (0.00% GC)
  mean time:        5.762 ms (2.40% GC)
  maximum time:     13.919 ms (26.41% GC)
  --------------
  samples:          867
  evals/sample:     1

Am I hitting some broadcasting splatting penalty or something?
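
For what it's worth, one way to probe that hypothesis (this is only a guess at where the cost goes, not a confirmed diagnosis) is to benchmark and inspect a single fused broadcast statement in isolation, with a small helper mirroring one of the updates above (the helper and its arguments are made up for illustration):

using BenchmarkTools

# Hypothetical helper mirroring the shape of the fused u update above.
update!(u, uprev, dt, a71, k1, a72, k2, a73, k3) =
    (u .= uprev .+ dt .* (a71 .* k1 .+ a72 .* k2 .+ a73 .* k3))

u = zeros(2); uprev = rand(2); k1 = rand(2); k2 = rand(2); k3 = rand(2)

@code_warntype update!(u, uprev, 0.1, 0.1, k1, 0.2, k2, 0.3, k3)   # is the fused kernel fully specialized?
@benchmark update!($u, $uprev, 0.1, 0.1, $k1, 0.2, $k2, 0.3, $k3)  # are there per-call allocations?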


Metadata


Labels

broadcast (Applying a function over a collection), performance (Must go faster)
