Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 14 additions & 11 deletions ensemble/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
from tqdm.auto import tqdm
from p_tqdm import p_map
import logging
from geostat.decomp import Cholesky # Making realizations

# Internal imports
import pipt.misc_tools.analysis_tools as at
from geostat.decomp import Cholesky # Making realizations
from pipt.misc_tools import cov_regularization
from pipt.misc_tools import wavelet_tools as wt
from misc import read_input_csv as rcsv
Expand Down Expand Up @@ -166,6 +167,7 @@ def _ext_ml_info(self):
if not self.ML_error_corr == 'none':
# options for error_comp_scheme are: once, ens, sep
self.error_comp_scheme = self.keys_en['multilevel'][i][2]
self.ML_corr_done = False

def _ext_prior_info(self):
"""
Expand Down Expand Up @@ -743,18 +745,19 @@ def calc_ml_prediction(self, input_state=None):
self.pred_data = np.array(ml_pred_data).T.tolist()

if hasattr(self,'treat_modeling_error'):
self.treat_modeling_error(self.iteration)
self.treat_modeling_error()

return success

def treat_modeling_error(self,iteration):
def treat_modeling_error(self):
if not self.ML_error_corr=='none':
if self.error_comp_scheme=='sep':
self.calc_modeling_error_sep()
self.address_ML_error(iteration)
self.address_ML_error()
elif self.error_comp_scheme=='once':
if iteration==0:
if not self.ML_corr_done:
self.calc_modeling_error_ens()
self.ML_corr_done = True
self.address_ML_error()
elif self.error_comp_scheme=='ens':
self.calc_modeling_error_ens()
Expand Down Expand Up @@ -866,12 +869,12 @@ def calc_modeling_error_sep(self):

def calc_modeling_error_ens(self):
#assim_step = 0
#tot_pred = []
#for level in range(self.tot_level):
# obs_data_vector, pred = at.aug_obs_pred_data(self.obs_data, [time_dat[level] for time_dat in self.pred_data],
# self.assim_index, self.list_datatypes) # get some data
tot_pred = []
for level in range(self.tot_level):
obs_data_vector, pred = at.aug_obs_pred_data(self.obs_data, [time_dat[level] for time_dat in self.pred_data],
self.assim_index, self.list_datatypes) # get some data
# pred = self.Dns_mat[level] * pred
# tot_pred.append(pred)
tot_pred.append(pred)
#if not isinstance(self.assim_index, list):
# self.assim_index = [self.assim_index]

Expand Down Expand Up @@ -966,7 +969,7 @@ def calc_modeling_error_ens(self):

#self.address_ML_error(self.iteration)

def address_ML_error(self,iteration):
def address_ML_error(self):
tot_pred = []
for level in range(self.tot_level):
obs_data_vector, pred = at.aug_obs_pred_data(self.obs_data, [time_dat[level] for time_dat in self.pred_data],
Expand Down