Skip to content

Commit

Permalink
Merge pull request #110 from oloapinivad/gpt-class
Browse files Browse the repository at this point in the history
ECmean based on class structure
  • Loading branch information
oloapinivad authored Dec 20, 2024
2 parents 4d6cc2c + 894a4aa commit 6360ace
Show file tree
Hide file tree
Showing 11 changed files with 822 additions and 707 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
Unreleased is the current development version, which is currently lying in `main` branch.

- Allowing for configuration file as dictionary (#106)
- GlobalMean and PerformanceIndices classe introduced (#110)

## [v0.1.11]

Expand Down
7 changes: 4 additions & 3 deletions ecmean/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from ecmean.libs.diagnostic import Diagnostic
from ecmean.libs.support import Supporter
from ecmean.libs.units import UnitsHandler
from ecmean.global_mean import global_mean
from ecmean.performance_indices import performance_indices
from ecmean.global_mean import GlobalMean, global_mean
from ecmean.performance_indices import PerformanceIndices, performance_indices

__all__ = ["global_mean", "performance_indices", "Diagnostic", "Supporter", "UnitsHandler"]
__all__ = ["GlobalMean", "global_mean", "PerformanceIndices",
"performance_indices", "Diagnostic", "Supporter", "UnitsHandler"]
596 changes: 281 additions & 315 deletions ecmean/global_mean.py

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions ecmean/libs/diagnostic.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(self, args):
self.resolution = getattr(args, 'resolution', '')
self.ensemble = getattr(args, 'ensemble', 'r1i1p1f1')
self.addnan = getattr(args, 'addnan', False)
self.funcname = args.funcname.split(".")[1]
self.funcname = args.funcname
self.version = version
if self.year1 == self.year2:
self.ftrend = False
Expand Down Expand Up @@ -82,11 +82,11 @@ def __init__(self, args):
self.figdir = Path(os.path.join(outputdir, 'PDF'))

# init for global mean
if self.funcname == 'global_mean':
if self.funcname == 'GlobalMean':
self.cfg_global_mean(cfg)

# init for performance indices
if self.funcname in 'performance_indices':
if self.funcname in 'PerformanceIndices':
self.cfg_performance_indices(cfg)

# setting up interface file
Expand Down
2 changes: 2 additions & 0 deletions ecmean/libs/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ def get_clim_files(piclim, var, diag, season):
vvvv = str(
diag.resclmdir /
f'{stringname}_variance_{var}_{dataref}_{diag.resolution}_{datayear1}-{datayear2}.nc')
else:
raise ValueError('Climatology not supported/existing!')

return clim, vvvv

Expand Down
40 changes: 9 additions & 31 deletions ecmean/libs/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,27 +17,8 @@
##################


# def is_number(s):
# """Check if input is a float type"""

# try:
# float(s)
# return True
# except ValueError:
# return False

loggy = logging.getLogger(__name__)

# def numeric_loglevel(loglevel):
# """Define the logging level """
# # log level with logging
# # currently basic definition trought the text
# numeric_level = getattr(logging, loglevel.upper(), None)
# if not isinstance(numeric_level, int):
# raise ValueError(f'Invalid log level: {loglevel}')

# return numeric_level

def set_multiprocessing_start_method():
"""Function to set the multiprocessing spawn method to fork"""
plat = platform.system()
Expand Down Expand Up @@ -163,23 +144,20 @@ def get_domain(var, face):
return domain[comp]


# def get_component(face): # unused function
# """Return a dictionary providing the domain associated with a variable
# (the interface file specifies the domain for each component instead)"""

# d = face['component']
# p = dict(zip([list(d.values())[x]['domain']
# for x in range(len(d.values()))], d.keys()))
# return p


####################
# OUTPUT FUNCTIONS #
####################

def dict_to_dataframe(varstat):
"""very clumsy conversion of the nested 3-level dictionary
to a pd.dataframe: NEED TO BE IMPROVED"""
"""
Converts a nested 3-level dictionary to a pandas DataFrame.
Parameters:
varstat (dict): Nested dictionary with 3 levels.
Returns:
pd.DataFrame: Transformed DataFrame with hierarchical keys.
"""
data_table = {}
for i in varstat.keys():
pippo = {}
Expand Down
189 changes: 169 additions & 20 deletions ecmean/libs/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,50 @@
##################

import textwrap
import logging
from matplotlib.colors import TwoSlopeNorm
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import seaborn as sns
import numpy as np
from ecmean.libs.general import dict_to_dataframe, init_mydict

loggy = logging.getLogger(__name__)

def heatmap_comparison_pi(data_dict, cmip6_dict, diag, longnames, filemap: str = None, size_model=14, **kwargs):
"""
Function to produce a heatmap - seaborn based - for Performance Indices
based on CMIP6 ratio
def heatmap_comparison_pi(relative_table, diag, filemap, size_model = 14):
"""Function to produce a heatmap - seaborn based - for Performance Indices
based on CMIP6 ratio"""
Args:
data_dict (dict): dictionary of absolute performance indices
cmip6_dict (dict): dictionary of CMIP6 performance indices
diag (object): Diagnostic object
units_list (list): list of units
filemap (str): path to save the plot
size_model (int): size of the PIs in the plot
Keyword Args:
title (str): title of the plot, overrides default title
"""

# convert output dictionary to pandas dataframe
data_table = dict_to_dataframe(data_dict)
loggy.debug("Data table")
loggy.debug(data_table)

# relative pi with re-ordering of rows
cmip6_table = dict_to_dataframe(cmip6_dict).reindex(longnames)
relative_table = data_table.div(cmip6_table)

# compute the total PI mean
relative_table.loc['Total PI'] = relative_table.mean()

# reordering columns if info is available
lll = [(x, y) for x in diag.seasons for y in diag.regions]
relative_table = relative_table[lll]
loggy.debug("Relative table")
loggy.debug(relative_table)

# defining plot size
myfield = relative_table
Expand All @@ -30,23 +63,24 @@ def heatmap_comparison_pi(relative_table, diag, filemap, size_model = 14):

thr = [0, 1, 5]
tictoc = [0, 0.25, 0.5, 0.75, 1, 2, 3, 4, 5]
title = 'CMIP6 RELATIVE PI'

if 'title' in kwargs:
title = kwargs['title']
else:
title = 'CMIP6 RELATIVE PI'
title += f" {diag.modelname} {diag.expname} {diag.year1} {diag.year2}"

# axs.subplots_adjust(bottom=0.2)
# pal = sns.diverging_palette(h_neg=130, h_pos=10, s=99, l=55, sep=3, as_cmap=True)
tot = len(myfield.columns)
sss = len(set([tup[1] for tup in myfield.columns]))
divnorm = TwoSlopeNorm(vmin=thr[0], vcenter=thr[1], vmax=thr[2])
pal = sns.color_palette("Spectral_r", as_cmap=True)
# pal = sns.diverging_palette(220, 20, as_cmap=True)
chart = sns.heatmap(myfield, norm=divnorm, cmap=pal,
cbar_kws={"ticks": tictoc, 'label': title},
ax=axs, annot=True, linewidth=0.5, fmt='.2f',
annot_kws={'fontsize': size_model, 'fontweight': 'bold'})

chart = chart.set_facecolor('whitesmoke')
axs.set_title(f'{title} {diag.modelname} {diag.expname} {diag.year1} {diag.year2}', fontsize=25)
axs.set_title(title, fontsize=25)
axs.vlines(list(range(sss, tot + sss, sss)), ymin=-1, ymax=len(myfield.index), colors='k')
axs.hlines(len(myfield.index) - 1, xmin=-1, xmax=len(myfield.columns), colors='purple', lw=2)
names = [' '.join(x) for x in myfield.columns]
Expand All @@ -56,16 +90,45 @@ def heatmap_comparison_pi(relative_table, diag, filemap, size_model = 14):
axs.figure.axes[-1].yaxis.label.set_size(15)
axs.set(xlabel=None)

if filemap is None:
filemap = 'PI4_heatmap.pdf'

# save and close
plt.savefig(filemap)
plt.cla()
plt.close()


def heatmap_comparison_gm(data_table, mean_table, std_table, diag, filemap, addnan = True,
size_model = 14, size_obs = 8):
"""Function to produce a heatmap - seaborn based - for Global Mean
based on season-averaged standard deviation ratio"""
def heatmap_comparison_gm(data_dict, mean_dict, std_dict, diag, units_list, filemap=None,
addnan=True, size_model=14, size_obs=8, **kwargs):
"""
Function to produce a heatmap - seaborn based - for Global Mean
based on season-averaged standard deviation ratio
Args:
data_dict (dict): table of model data
mean_dict (dict): table of observations
std_dict (dict): table of standard deviation
diag (dict): diagnostic object
units_list (list): list of units
filemap (str): path to save the plot
addnan (bool): add to the final plots also fields which cannot be compared against observations
size_model (int): size of the model values in the plot
size_obs (int): size of the observation values in the plot
Keyword Args:
title (str): title of the plot, overrides default title
"""

# convert the three dictionary to pandas and then add units
data_table = dict_to_dataframe(data_dict)
mean_table = dict_to_dataframe(mean_dict)
std_table = dict_to_dataframe(std_dict)
for table in [data_table, mean_table, std_table]:
table.index = table.index + ' [' + units_list + ']'

loggy.debug("Data table")
loggy.debug(data_table)

# define array
ratio = (data_table - mean_table) / std_table
Expand All @@ -78,9 +141,13 @@ def heatmap_comparison_gm(data_table, mean_table, std_table, diag, filemap, addn
# for dimension of plots
xfig = len(clean.columns)
yfig = len(clean.index)
fig, axs = plt.subplots(1, 1, sharey=True, tight_layout=True, figsize=(xfig+5, yfig+2))
_, axs = plt.subplots(1, 1, sharey=True, tight_layout=True, figsize=(xfig+5, yfig+2))

title = 'GLOBAL MEAN'
if 'title' in kwargs:
title = kwargs['title']
else:
title = 'GLOBAL MEAN'
title += f" {diag.modelname} {diag.expname} {diag.year1} {diag.year2}"

# set color range and palette
thr = 10
Expand All @@ -96,10 +163,10 @@ def heatmap_comparison_gm(data_table, mean_table, std_table, diag, filemap, addn
fmt='.2f', cmap=pal)
if addnan:
empty = np.where(clean.isna(), 0, np.nan)
empty = np.where(data_table[mask]==0, np.nan, empty)
empty = np.where(data_table[mask] == 0, np.nan, empty)
chart = sns.heatmap(empty, annot=data_table[mask], fmt='.2f',
vmin=-thr - 0.5, vmax=thr + 0.5, center=0,
annot_kws={'va': 'bottom', 'fontsize': size_model, 'color':'dimgrey'}, cbar=False,
annot_kws={'va': 'bottom', 'fontsize': size_model, 'color': 'dimgrey'}, cbar=False,
cmap=ListedColormap(['none']))
chart = sns.heatmap(clean, annot=mean_table[mask], vmin=-thr - 0.5, vmax=thr + 0.5, center=0,
annot_kws={'va': 'top', 'fontsize': size_obs, 'fontstyle': 'italic'},
Expand All @@ -108,11 +175,11 @@ def heatmap_comparison_gm(data_table, mean_table, std_table, diag, filemap, addn
empty = np.where(clean.isna(), 0, np.nan)
empty = np.where(mean_table[mask].isna(), np.nan, empty)
chart = sns.heatmap(empty, annot=mean_table[mask], vmin=-thr - 0.5, vmax=thr + 0.5, center=0,
annot_kws={'va': 'top', 'fontsize': size_obs, 'fontstyle': 'italic', 'color':'dimgrey'},
fmt='.2f', cmap=ListedColormap(['none']), cbar=False)
annot_kws={'va': 'top', 'fontsize': size_obs, 'fontstyle': 'italic', 'color': 'dimgrey'},
fmt='.2f', cmap=ListedColormap(['none']), cbar=False)

chart = chart.set_facecolor('whitesmoke')
axs.set_title(f'{title} {diag.modelname} {diag.expname} {diag.year1} {diag.year2}', fontsize=25)
axs.set_title(title, fontsize=25)
axs.vlines(list(range(sss, tot + sss, sss)), ymin=-1, ymax=len(clean.index), colors='k')
names = [' '.join(x) for x in clean.columns]
axs.set_xticks([x + .5 for x in range(len(names))], names, rotation=45, ha='right', fontsize=16)
Expand All @@ -122,7 +189,89 @@ def heatmap_comparison_gm(data_table, mean_table, std_table, diag, filemap, addn
axs.figure.axes[-1].yaxis.label.set_size(15)
axs.set(xlabel=None)

if filemap is None:
filemap = 'Global_Mean_Heatmap.pdf'

# save and close
plt.savefig(filemap)
plt.cla()
plt.close()

def prepare_clim_dictionaries_pi(data, clim, shortnames):
"""
Prepare dictionaries for plotting
Args:
data: dictionary with data
clim: dictionary with climatology
shortnames: list of shortnames
Returns:
data2plot: dictionary with data for plotting
cmip6: dictionary with CMIP6 data
longnames: list of longnames
"""

# uniform dictionaries
filt_piclim = {}
for k in clim.keys():
filt_piclim[k] = clim[k]['cmip6']
for f in ['models', 'year1', 'year2']:
del filt_piclim[k][f]

# set longname, reorganize the dictionaries
data2plot = {}
cmip6 = {}
longnames = [clim[var]['longname'] for var in shortnames]
for var in shortnames:
longname = clim[var]['longname']
data2plot[longname] = data[var]
cmip6[longname] = filt_piclim[var]

return data2plot, cmip6, longnames

def prepare_clim_dictionaries_gm(data, clim, shortnames, seasons, regions):
"""
Prepare dictionaries for global mean plotting
Args:
data: dictionary with the data
clim: dictionary with the climatology
shortnames: list of shortnames
seasons: list of seasons
regions: list of regions
Returns:
obsmean: dictionary with the mean
obsstd: dictionary with the standard deviation
data2plot: dictionary with the data to plot
units_list: list of units
"""

# loop on the variables to create obsmean and obsstd
obsmean = {}
obsstd = {}
for var in shortnames:
gamma = clim[var]

# extract from yaml table for obs mean and standard deviation
mmm = init_mydict(seasons, regions)
sss = init_mydict(seasons, regions)
# if we have all the obs/std available
if isinstance(gamma['obs'], dict):
for season in seasons:
for region in regions:
mmm[season][region] = gamma['obs'][season][region]['mean']
sss[season][region] = gamma['obs'][season][region]['std']
# if only global observation is available
else:
mmm['ALL']['Global'] = gamma['obs']
obsmean[gamma['longname']] = mmm
obsstd[gamma['longname']] = sss

# set longname, get units
data2plot = {}
units_list = []
for var in shortnames:
data2plot[clim[var]['longname']] = data[var]
units_list.append(clim[var]['units'])

return obsmean, obsstd, data2plot, units_list
Loading

0 comments on commit 6360ace

Please sign in to comment.