
Conversation

@penelopeysm (Member) commented Oct 23, 2025

Closes #1081 (Unable to compute sparse Hessians with ThreadSafeVarInfo)

I traced the error all the way back to unflatten, and the problem is partly related to #906, but it's perhaps even more subtle than that.

unflatten DOES attempt to change the eltypes of the accumulators based on the parameters. In other words, it does try its best to work for all Real types. (Not doing so would cause issues with ForwardDiff, and we don't see any such issues, so clearly something was working.)

accs = map(
    acc -> convert_eltype(float_type_with_fallback(eltype(x)), acc), copy(getaccs(vi))
)

However, it does this specifically using float_type_with_fallback, and for types T <: Real we currently have the definition

DynamicPPL.jl/src/utils.jl

Lines 777 to 784 in 1b159a6

"""
float_type_with_fallback(T::DataType)
Return `float(T)` if possible; otherwise return `float(Real)`.
"""
float_type_with_fallback(::Type) = float(Real)
float_type_with_fallback(::Type{Union{}}) = float(Real)
float_type_with_fallback(::Type{T}) where {T<:Real} = float(T)
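
A quick REPL check of these current definitions (importing the internal function directly):

julia> import DynamicPPL: float_type_with_fallback

julia> float_type_with_fallback(Float32)   # T <: Real, so float(T)
Float32

julia> float_type_with_fallback(Int)       # T <: Real, so float(T)
Float64

julia> float_type_with_fallback(String)    # not a Real; fall back to float(Real)
Float64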

The reason this Just Works with ForwardDiff is that

julia> import ForwardDiff; import SparseConnectivityTracer as SCT

julia> x = ForwardDiff.Dual(0.5); float(typeof(x))
ForwardDiff.Dual{Nothing, Float64, 0}

But for SCT we have that

julia> x = SCT.Dual{Float64, SCT.HessianTracer{Int64, BitSet, Dict{Int64, BitSet}, SCT.NotShared}}(0.5); float(typeof(x))
Float64
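
So when unflatten computes the target eltype for the accumulators from an SCT dual, the T <: Real method collapses it to plain Float64, and the tracer is lost:

julia> float_type_with_fallback(typeof(x))   # x is the SCT dual from above
Float64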

This PR therefore makes the smallest possible change to get this to work for both dual types.

It also adds more methods to handle:

  • Integers: we don't want our log-prob accumulators to use integers.
  • Real: we want to concretise. This somehow preserves type stability, and I don't know how, but I think there's some dark magic with eltype, TypeWrap and matchingvalue going on here.

These methods mean that the old behaviour of float(T) for such types is preserved.
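
Put together, the method table after this PR looks roughly like the following (a sketch reconstructed from the description above, not the literal diff):

float_type_with_fallback(::Type) = float(Real)                     # non-Real types: fall back to float(Real)
float_type_with_fallback(::Type{Union{}}) = float(Real)
float_type_with_fallback(::Type{T}) where {T<:Real} = T            # leave Reals (including dual types) untouched
float_type_with_fallback(::Type{T}) where {T<:Integer} = float(T)  # but never accumulate log probs in integers
float_type_with_fallback(::Type{Real}) = float(Real)               # and concretise abstract Real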

A note on ForwardDiff

One might think that this PR breaks AD when evaluating with Vector{Int} parameters: the old behaviour converted ForwardDiff.Dual{tag,Int} to ForwardDiff.Dual{tag,Float64}, whereas this PR leaves it untouched.

As it turns out, ForwardDiff can't be used with Vector{Int} parameters anyway: the output gradient is stored in a Vector{Int}, and although DI correctly calculates the gradient [0.5, 0.5], it errors when trying to insert those values back into the Vector{Int}.

import DifferentiationInterface as DI
import ForwardDiff
f(x) = 0.5 * (x[1] + x[2])
# errors: [0.5, 0.5] cannot be stored back into a Vector{Int}
DI.value_and_gradient(f, DI.AutoForwardDiff(), [1, 1])

so that case can be safely ignored.
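
For contrast, the same call with float parameters works as expected:

julia> DI.value_and_gradient(f, DI.AutoForwardDiff(), [1.0, 1.0])
(1.0, [0.5, 0.5])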

Comment on lines 370 to 372
end
end

@penelopeysm (Member Author)

No change to the code here, just removing the redundant VERSION check since our min bound is 1.10.

@github-actions bot (Contributor) commented Oct 23, 2025

Benchmark Report for Commit 9f89ee7

Computer Information

Julia Version 1.11.7
Commit f2b3dbda30a (2025-09-08 12:10 UTC)
Build Info:
  Official https://julialang.org/ release
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 4 × AMD EPYC 7763 64-Core Processor
  WORD_SIZE: 64
  LLVM: libLLVM-16.0.6 (ORCJIT, znver3)
Threads: 1 default, 0 interactive, 1 GC (on 4 virtual cores)

Benchmark Results

┌───────────────────────┬───────┬─────────────┬───────────────────┬────────┬────────────────┬─────────────────┐
│                 Model │   Dim │  AD Backend │           VarInfo │ Linked │ t(eval)/t(ref) │ t(grad)/t(eval) │
├───────────────────────┼───────┼─────────────┼───────────────────┼────────┼────────────────┼─────────────────┤
│ Simple assume observe │     1 │ forwarddiff │             typed │  false │            6.8 │             1.8 │
│           Smorgasbord │   201 │ forwarddiff │             typed │  false │          742.8 │            42.4 │
│           Smorgasbord │   201 │ forwarddiff │ simple_namedtuple │   true │          439.4 │            52.0 │
│           Smorgasbord │   201 │ forwarddiff │           untyped │   true │          837.9 │            35.7 │
│           Smorgasbord │   201 │ forwarddiff │       simple_dict │   true │         6753.3 │            30.6 │
│           Smorgasbord │   201 │ reversediff │             typed │   true │          764.5 │            55.9 │
│           Smorgasbord │   201 │    mooncake │             typed │   true │          759.0 │             6.2 │
│           Smorgasbord │   201 │      enzyme │             typed │   true │          945.3 │             3.8 │
│    Loop univariate 1k │  1000 │    mooncake │             typed │   true │         4017.2 │             5.8 │
│       Multivariate 1k │  1000 │    mooncake │             typed │   true │         1013.4 │             9.1 │
│   Loop univariate 10k │ 10000 │    mooncake │             typed │   true │        44545.9 │             5.2 │
│      Multivariate 10k │ 10000 │    mooncake │             typed │   true │         8627.7 │            10.4 │
│               Dynamic │    10 │    mooncake │             typed │   true │          129.6 │            11.6 │
│              Submodel │     1 │    mooncake │             typed │   true │            9.4 │             6.3 │
│                   LDA │    12 │ reversediff │             typed │   true │         1037.7 │             2.0 │
└───────────────────────┴───────┴─────────────┴───────────────────┴────────┴────────────────┴─────────────────┘

@codecov bot commented Oct 23, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 81.31%. Comparing base (1b159a6) to head (9f89ee7).
⚠️ Report is 3 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1088      +/-   ##
==========================================
+ Coverage   81.06%   81.31%   +0.25%     
==========================================
  Files          40       40              
  Lines        3749     3751       +2     
==========================================
+ Hits         3039     3050      +11     
+ Misses        710      701       -9     


@penelopeysm penelopeysm marked this pull request as draft October 24, 2025 09:11
@penelopeysm penelopeysm requested a review from mhauru October 27, 2025 11:00
@penelopeysm (Member Author)

@mhauru This is not quite the right fix (it still does not respect the logp precision in the accumulator), but it will fix CI. Not sure what you think. I'm rather on the fence about going forward with this.

@mhauru (Member) left a comment

The old thing was hacky, and this thing is hacky, but hacky in a way that makes more things work, so it seems like an improvement to me. Any particular reason to decide not to do this? I assume the conclusion still holds that the proper fix needs special casing on AD tracer types.

@penelopeysm (Member Author) commented Oct 27, 2025

Yeah the correct fix still needs overloading on specific types. The main reason why I'm hesitant is technical debt, plus the knowledge that once this fix is in, any motivation I have to fix it properly will vanish.

@github-actions bot (Contributor)

DynamicPPL.jl documentation for PR #1088 is available at:
https://TuringLang.github.io/DynamicPPL.jl/previews/PR1088/

@penelopeysm (Member Author)

(Also, the original issue might be fixed by adrhill/SparseConnectivityTracer.jl#279)

@mhauru (Member) commented Oct 27, 2025

Not sure if this adds much technical debt; I feel like this is about as hacky as what was in place. A tiny bit more complicated, yes. Happy either way, to merge this or wait for the upstream fix, as long as the latter doesn't take very long.

@adrhill commented Oct 27, 2025

> as long as the latter doesn't take very long.

I will tag an SCT 1.1.2 release today.

@mhauru (Member) commented Oct 27, 2025

That definitely counts as not very long. :) Thank you!

@adrhill commented Oct 27, 2025

The release is tagged and should be available soon: JuliaRegistries/General#141344

@penelopeysm (Member Author)

CI on main passes now! Thank you @adrhill 😊

@yebai yebai deleted the py/tsvi branch October 30, 2025 17:28
