Skip to content

Commit

Permalink
Merge pull request #10 from Team-RADDISH/dg_patch
Browse files Browse the repository at this point in the history
SPEEDY: Patch to fix n_tasks
  • Loading branch information
DanGiles authored Nov 27, 2024
2 parents 9014f57 + b2c25f1 commit 2ba0d22
Showing 1 changed file with 19 additions and 8 deletions.
27 changes: 19 additions & 8 deletions SPEEDY/model/model.jl
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,7 @@ function ParticleDA.sample_initial_state!(
state::AbstractVector{T},
model_data::ModelData,
rng::Random.AbstractRNG,
ind=1,
) where T
# Read in arbitrary nature run files for the initial conditions
dummy_date = step_ens(model_data.model_params.ensDate, model_data.model_params.Hinc, rng)
Expand Down Expand Up @@ -358,14 +359,22 @@ function get_station_grid_indices(
readline(f)
readline(f)
count = 0
ind = 1
for line in eachline(f)
if count%10 == 0
ind = 1
if nobs == 50
for line in eachline(f)
if count%10 == 0
station_grid_indices[ind, 1] = parse(Int64,split(line)[1])
station_grid_indices[ind, 2] = parse(Int64,split(line)[2])
ind += 1
end
count = count + 1
end
else
for line in eachline(f)
station_grid_indices[ind, 1] = parse(Int64,split(line)[1])
station_grid_indices[ind, 2] = parse(Int64,split(line)[2])
ind += 1
end
count = count + 1
end
end
return station_grid_indices
Expand Down Expand Up @@ -509,7 +518,7 @@ function ParticleDA.get_state_indices_correlated_to_observations(model_data::Mod
)
end

function init(model_params_dict::Dict)
function init(model_params_dict::Dict, n_tasks=1)

model_params = ParticleDA.get_params(
ModelParameters, get(model_params_dict, "speedy", Dict())
Expand Down Expand Up @@ -581,7 +590,8 @@ function ParticleDA.sample_observation_given_state!(
observation::AbstractVector{S},
state::AbstractVector{T},
model_data::ModelData,
rng::AbstractRNG
rng::AbstractRNG,
task_index=1,
) where{S, T}
ParticleDA.get_observation_mean_given_state!(observation, state, model_data)
observation .+= rand(
Expand All @@ -594,6 +604,7 @@ function ParticleDA.get_log_density_observation_given_state(
observation::AbstractVector,
state::AbstractVector,
model_data::ModelData,
task_index=1,
)
observation_mean = view(model_data.observation_buffer, :, threadid())
ParticleDA.get_observation_mean_given_state!(observation_mean, state, model_data)
Expand All @@ -604,7 +615,7 @@ function ParticleDA.get_log_density_observation_given_state(
end

function ParticleDA.update_state_deterministic!(
state::AbstractVector, d::ModelData, time_index::Int
state::AbstractVector, d::ModelData, time_index::Int, task_index=1,
)
state_fields = flat_state_to_fields(state, d.model_params)
my_rank = MPI.Comm_rank(MPI.COMM_WORLD)
Expand Down Expand Up @@ -636,7 +647,7 @@ function ParticleDA.update_state_deterministic!(
end

function ParticleDA.update_state_stochastic!(
state::AbstractVector, model::ModelData, rng::AbstractRNG
state::AbstractVector, model::ModelData, rng::AbstractRNG, task_index=1
)
# Add state noise
add_random_field!(
Expand Down

0 comments on commit 2ba0d22

Please sign in to comment.