Skip to content

Commit d0c9261

Browse files
authored
Merge pull request tpapp#120 from devmotion/dw/namedtuple
Support NamedTuple of different ordering
2 parents 3a55e93 + d551fc1 commit d0c9261

File tree

3 files changed

+19
-5
lines changed

3 files changed

+19
-5
lines changed

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "TransformVariables"
22
uuid = "84d833dd-6860-57f9-a1a7-6da5db126cff"
33
authors = ["Tamas K. Papp <tkpapp@gmail.com>"]
4-
version = "0.8.10"
4+
version = "0.8.11"
55

66
[deps]
77
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"

src/aggregation.jl

+8-4
Original file line numberDiff line numberDiff line change
@@ -391,14 +391,18 @@ end
391391

392392
function inverse_eltype(tt::TransformTuple{<:NamedTuple}, y::NamedTuple)
393393
@unpack transformations = tt
394-
@argcheck keys(transformations) == keys(y)
395-
_inverse_eltype_tuple(values(transformations), values(y))
394+
@argcheck _same_set_of_names(transformations, y)
395+
_inverse_eltype_tuple(values(transformations), values(NamedTuple{keys(transformations)}(y)))
396396
end
397397

398398
function inverse_at!(x::AbstractVector, index, tt::TransformTuple{<:NamedTuple}, y::NamedTuple)
399399
@unpack transformations = tt
400-
@argcheck keys(transformations) == keys(y)
401-
_inverse!_tuple(x, index, values(transformations), values(y))
400+
@argcheck _same_set_of_names(transformations, y)
401+
_inverse!_tuple(x, index, values(transformations), values(NamedTuple{keys(transformations)}(y)))
402+
end
403+
404+
function _same_set_of_names(x::NamedTuple, y::NamedTuple)
405+
return length(x) == length(y) && Base.structdiff(x, y) === (;)
402406
end
403407

404408
function _domain_label(t::TransformTuple, index::Int)

test/runtests.jl

+10
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,16 @@ end
316316
@test x x′
317317
end
318318

319+
@testset "different order and superset of NamedTuple" begin
320+
# test for #100
321+
t = as((a = asℝ, b = asℝ))
322+
@test @inferred(inverse(t, (a = 1.0, b = 2.0))) == [1.0, 2.0]
323+
@test @inferred(inverse(t, (b = 2.0, a = 1.0))) == [1.0, 2.0]
324+
@test_throws ArgumentError inverse(t, (; a = 1.0))
325+
@test_throws ArgumentError inverse(t, (a = 1.0, b = 2.0, c = 3.0))
326+
@test_throws ArgumentError inverse(t, (a = 1.0, c = 2.0))
327+
end
328+
319329
####
320330
#### log density correctness checks
321331
####

0 commit comments

Comments
 (0)