Skip to content
Merged

Dev #173

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
dbbfe33
autopush
SCiarella Feb 18, 2025
f611572
Add YAMLloader for attentioncnn
SCiarella Feb 25, 2025
b242d5d
autopush
SCiarella Feb 25, 2025
f015d81
Add CNO loader
SCiarella Feb 26, 2025
2631b7f
autopush
SCiarella Mar 5, 2025
0e429b5
autopush
SCiarella Mar 12, 2025
16b74f9
autopush
SCiarella Mar 12, 2025
633584a
autopush
SCiarella Mar 12, 2025
f0cbb69
Merge branch 'main' into dev
SCiarella Mar 12, 2025
6f95e10
autopush
SCiarella Mar 14, 2025
afb1ee1
Merge branch 'main' into dev
SCiarella Mar 14, 2025
4763b1d
Merge branch 'main' into dev
SCiarella Mar 20, 2025
7a2f01c
autopush
SCiarella Mar 20, 2025
8d28622
Merge branch 'main' into dev
SCiarella Mar 20, 2025
710fd35
autopush
SCiarella Mar 26, 2025
5b184c8
Merge branch 'main' into dev
SCiarella Mar 26, 2025
9b9434c
autopush
SCiarella Mar 27, 2025
3f49058
Merge branch 'main' into dev
SCiarella Mar 27, 2025
f4bbc6f
autopush
SCiarella Apr 7, 2025
e19980f
Add pre-commit
SCiarella Apr 7, 2025
b1ae853
Bump Plots
SCiarella Apr 7, 2025
d38ecd8
Bump down Zygote
SCiarella Apr 7, 2025
84286fe
autopush
SCiarella Apr 7, 2025
6e4a4b8
autopush
SCiarella Apr 7, 2025
fa3be6b
Bump CairoMakie
SCiarella Apr 7, 2025
b74debb
Merge branch 'main' into dev
SCiarella Apr 7, 2025
a8fa41a
autopush
SCiarella Apr 7, 2025
8d2b15d
Merge branch 'main' into dev
SCiarella Apr 7, 2025
8e7b156
autopush
SCiarella Apr 14, 2025
939daba
autopush
SCiarella Apr 14, 2025
15e1aea
autopush
SCiarella Apr 14, 2025
d081727
autopush
SCiarella Apr 14, 2025
53957c0
autopush
SCiarella Apr 14, 2025
5854fc9
autopush
SCiarella Apr 15, 2025
c868505
autopush
SCiarella Apr 15, 2025
f3859fb
Merge branch 'main' into dev
SCiarella Apr 15, 2025
df7927f
autopush
SCiarella Apr 15, 2025
4a46a02
Merge branch 'main' into dev
SCiarella Apr 15, 2025
30c221d
autopush
SCiarella Apr 28, 2025
c464ac0
autopush
SCiarella Apr 29, 2025
e577e9b
Update to INS v3
SCiarella Apr 29, 2025
44ddd11
fix test
SCiarella Apr 29, 2025
e199f49
Comply with INS on CPU backend
SCiarella Apr 29, 2025
bc089d7
Up Zygote
SCiarella Apr 29, 2025
5aaace8
Merge branch 'main' into dev
SCiarella Apr 29, 2025
16854d4
Controllable sensealg
SCiarella May 9, 2025
d4a6e87
autopush
SCiarella May 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/loss/loss_posteriori.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,15 @@ function create_loss_post_lux(
end
kwargs = (; kwargs..., dt = dt)
end
# Drop `sensealg` from the solver kwargs when it was explicitly passed as
# `nothing`, so `solve` falls back to its default sensitivity algorithm
# instead of receiving `sensealg = nothing`.
# BUG FIX: the original condition `!(:sensealg in keys(kwargs)) && kwargs[:sensealg] == nothing`
# could never strip the key — with the key present the first operand is false,
# and with the key absent the lookup `kwargs[:sensealg]` throws. The intended
# check is presence AND value `nothing` (compared with `isnothing`, not `==`).
if :sensealg in keys(kwargs) && isnothing(kwargs[:sensealg])
    kwargs = (; (k => v for (k, v) in pairs(kwargs) if k != :sensealg)...)
    # Leftover debug banners (@info/@warn) replaced by a single @debug so the
    # diagnostic survives without polluting normal training logs.
    @debug "Removed `sensealg = nothing` from solver kwargs" kwargs
end

function get_tspan(t)
# To avoid problems with SciMLBase.promote_tspan,
# we have to return t_span as a tuple on the CPU
Expand All @@ -202,7 +211,6 @@ function create_loss_post_lux(
prob = ODEProblem(rhs, x, tspan, ps)
pred = dev(ArrayType(solve(
prob, sciml_solver; u0 = x, p = ps, adaptive = false, saveat = Array(t), kwargs...)))
#prob, sciml_solver; u0 = x, p = ps, adaptive = false, saveat = collect(t), kwargs...)))
# remember that the first element of pred is the initial condition (SciML)
return sum(
abs2, y[griddims..., :, 1:(size(pred, 4) - 1)] - pred[griddims..., :, 2:end]) /
Expand Down
127 changes: 127 additions & 0 deletions test/test_gpu_sensealg.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
using Test
using Random: Random
using IncompressibleNavierStokes: IncompressibleNavierStokes as INS
using JLD2: load, @save
using CoupledNODE: cnn, train, create_loss_post_lux
NS = Base.get_extension(CoupledNODE, :NavierStokes)
using DifferentialEquations: ODEProblem, solve, Tsit5
using ComponentArrays: ComponentArray
using Lux: Lux
using CUDA
#using CUDSS # Warning: loading CUDSS without a CUDA device breaks the CI
using cuDNN
using LuxCUDA
using Adapt
using Optimization: Optimization
using OptimizationOptimisers: OptimizationOptimisers
using SciMLSensitivity

# Define the test set
# NOTE(review): `NON_supported` is never referenced below — it appears to
# document the sensitivity algorithms known NOT to work in this GPU setup.
# Consider either asserting that they fail, or removing the binding.
NON_supported = [ForwardDiffSensitivity(), ZygoteAdjoint(), TrackerAdjoint(),
    ReverseDiffAdjoint(), QuadratureAdjoint()]
# Sensitivity algorithms exercised by the loop below; each is forwarded to
# `create_loss_post_lux` as `sensealg`. `nothing` requests the default choice.
sensealgs = [nothing, GaussAdjoint(), BacksolveAdjoint(), InterpolatingAdjoint()]
# For each sensitivity algorithm, run the full a-posteriori training pipeline
# on GPU (data loading, CNN closure, loss construction, short training run) and
# check that losses stay finite and parameters are produced.
for SENSEALG_i in sensealgs
    # Skip (with a trivially passing testset) when no CUDA device is present.
    # BUG FIX: the original used `return` inside the @testset body, which is not
    # allowed at top-level loop scope (`@testset` does not introduce a function);
    # `continue` performs the intended per-iteration skip.
    if !CUDA.functional()
        @testset "CUDA not available" begin
            @test true
        end
        continue
    end

    @testset "Sensealg (GPU) with $SENSEALG_i" begin
        T = Float32
        rng = Random.Xoshiro(123) # fixed seed for reproducibility
        ig = 1 # index of the LES grid to use.

        @test CUDA.functional() # Check that CUDA is available

        # Load the training data and parameters produced by the data-generation
        # step of the test suite.
        data = load("test_data/data_train.jld2", "data_train")
        params = load("test_data/params_data.jld2", "params")

        # Use gpu device
        backend = CUDABackend()
        CUDA.allowscalar(false) # fail fast on accidental scalar GPU indexing
        device = x -> adapt(CuArray{T}, x)

        # Build LES setups and assemble operators
        setups = map(params.nles) do nles
            x = ntuple(α -> LinRange(T(0.0), T(1.0), nles + 1), params.D)
            INS.Setup(; x = x, Re = params.Re, backend = backend)
        end

        # A posteriori io_arrays
        io_post = NS.create_io_arrays_posteriori(data, setups, device)

        # Example of dimensions and how to operate with io_arrays_posteriori
        (n, _, dim, samples, nsteps) = size(io_post[ig].u) # (nles, nles, D, samples, tsteps+1)
        (samples, nsteps) = size(io_post[ig].t)

        # Create dataloader containing trajectories with the specified nunroll
        nunroll = 5
        dataloader_posteriori = NS.create_dataloader_posteriori(
            io_post[ig]; nunroll = nunroll, rng = rng, device = device)
        train_data_post = dataloader_posteriori()

        # Load the test data
        test_data = load("test_data/data_test.jld2", "data_test")
        test_io_post = NS.create_io_arrays_posteriori(test_data, setups)

        u = train_data_post[1]
        d = D = setups[1].grid.dimension()
        N = size(u, 1)

        # Define the CNN closure model (two layers, GPU-enabled)
        closure, θ, st = cnn(;
            T = T,
            D = D,
            data_ch = D,
            radii = [3, 3],
            channels = [2, 2],
            activations = [tanh, identity],
            use_bias = [false, false],
            rng = rng,
            use_cuda = true
        )

        # Test and trigger the model (forces compilation of the forward pass)
        test_output = Lux.apply(closure, u, θ, st)[1]

        # Define the right hand side of the ODE
        dudt_nn2 = NS.create_right_hand_side_with_closure(
            setups[ig], INS.psolver_spectral(setups[ig]), closure, st)

        # Define the loss (a-posteriori) with the sensitivity algorithm under test
        train_data_posteriori = dataloader_posteriori()
        loss_posteriori_lux = create_loss_post_lux(
            dudt_nn2; sciml_solver = Tsit5(), use_cuda = true, sensealg = SENSEALG_i)
        loss_value = loss_posteriori_lux(closure, θ, st, train_data_posteriori)
        @test isfinite(loss_value[1]) # Check that the loss value is finite

        # Callback function (plotting disabled; validation on the test io-arrays)
        callbackstate_val, callback_val = NS.create_callback(
            dudt_nn2, θ, test_io_post[ig], loss_posteriori_lux, st, nunroll = 3 * nunroll,
            rng = rng, do_plot = false, plot_train = false, device = device)
        θ_posteriori = θ

        # Training via Lux (short run: 10 epochs is enough for a smoke test)
        lux_result, lux_t, lux_mem, _ = @timed train(
            closure, θ_posteriori, st, dataloader_posteriori, loss_posteriori_lux;
            nepochs = 10, ad_type = Optimization.AutoZygote(),
            alg = OptimizationOptimisers.Adam(0.001), cpu = false, callback = nothing)

        loss, tstate = lux_result
        # Check that the training loss is finite
        @test isfinite(loss)
        # Loose sanity bound: training must not blow the loss up by more than 10x
        @test loss < 10 * loss_value[1]
        @info "Training loss: $loss"

        # The trained parameters at the end of the training are:
        θ_posteriori = tstate.parameters
        @test !isnothing(θ_posteriori) # Check that the trained parameters are not nothing
    end
end
Loading