-
-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Co-authored-by: Gaurav Arya <gauravarya272@gmail.com>
- Loading branch information
1 parent
d3738a1
commit 753127d
Showing
6 changed files
with
110 additions
and
83 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
|
||
""" | ||
NoShow(layer) | ||
NoShow(string, layer) | ||
This alters printing (for instance at the REPL prompt) to let you hide the complexity | ||
of some part of a Flux model. It has no effect on the actual running of the model. | ||
By default it prints `NoShow(...)` instead of the given layer. | ||
If you provide a string, it prints that instead -- it can be anything, | ||
but it may make sense to print the name of a function which will | ||
re-create the same structure. | ||
# Examples | ||
```jldoctest | ||
julia> Chain(Dense(2 => 3), NoShow(Parallel(vcat, Dense(3 => 4), Dense(3 => 5))), Dense(9 => 10)) | ||
Chain( | ||
Dense(2 => 3), # 9 parameters | ||
NoShow(...), # 36 parameters | ||
Dense(9 => 10), # 100 parameters | ||
) # Total: 8 arrays, 145 parameters, 1.191 KiB. | ||
julia> pseudolayer((i,o)::Pair) = NoShow( | ||
"pseudolayer(\$i => \$o)", | ||
Parallel(+, Dense(i => o, relu), Dense(i => o, tanh)), | ||
) | ||
pseudolayer (generic function with 1 method) | ||
julia> Chain(Dense(2 => 3), pseudolayer(3 => 10), Dense(9 => 10)) | ||
Chain( | ||
Dense(2 => 3), # 9 parameters | ||
pseudolayer(3 => 10), # 80 parameters | ||
Dense(9 => 10), # 100 parameters | ||
) # Total: 8 arrays, 189 parameters, 1.379 KiB. | ||
``` | ||
""" | ||
struct NoShow{T} | ||
str::String | ||
layer::T | ||
end | ||
|
||
NoShow(layer) = NoShow("NoShow(...)", layer) | ||
|
||
Flux.@functor NoShow | ||
|
||
(no::NoShow)(x...) = no.layer(x...) | ||
|
||
Base.show(io::IO, no::NoShow) = print(io, no.str) | ||
|
||
Flux._show_leaflike(::NoShow) = true # I think this is right | ||
Flux._show_children(::NoShow) = (;) # Seems to be needed? | ||
|
||
function Base.show(io::IO, ::MIME"text/plain", m::NoShow) | ||
if get(io, :typeinfo, nothing) === nothing # e.g., top level of REPL | ||
Flux._big_show(io, m) | ||
elseif !get(io, :compact, false) # e.g., printed inside a Vector, but not a matrix | ||
Flux._layer_show(io, m) | ||
else | ||
show(io, m) | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
|
||
@testset "NoShow" begin | ||
d23 = Dense(2 => 3) | ||
d34 = Dense(3 => 4, tanh) | ||
d35 = Dense(3 => 5, relu) | ||
d910 = Dense(9 => 10) | ||
|
||
model = Chain(d23, Parallel(vcat, d34, d35), d910) | ||
m_no = Chain(d23, NoShow(Parallel(vcat, d34, NoShow("zzz", d35))), d910) | ||
|
||
@test sum(length, Flux.params(model)) == sum(length, Flux.params(m_no)) | ||
|
||
xin = randn(Float32, 2, 7) | ||
@test model(xin) ≈ m_no(xin) | ||
|
||
# gradients | ||
grad = gradient(m -> m(xin)[1], model)[1] | ||
g_no = gradient(m -> m(xin)[1], m_no)[1] | ||
|
||
@test grad.layers[2].layers[1].bias ≈ g_no.layers[2].layer.layers[1].bias | ||
@test grad.layers[2].layers[2].bias ≈ g_no.layers[2].layer.layers[2].layer.bias | ||
|
||
# printing -- see also compact.jl for another test | ||
@test !contains(string(model), "NoShow(...)") | ||
@test contains(string(m_no), "NoShow(...)") | ||
@test !contains(string(m_no), "3 => 4") | ||
end | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters