Skip to content

Commit

Permalink
Fix issue where weights were not correctly added to ds
Browse files Browse the repository at this point in the history
If the grids of `weights` and `ds` were close but not equal (floating point errors), then `weights` is not correctly added to `ds`, causing `weights` to be entirely 0s.

This has been fixed by explicitly setting the weights grid to the ds grid if their grids are `np.allclose()`.

Warnings have been added if the resultant homogenized `weights` array is all nans or all 0s.

`xa.normalize()` has been updated to ensure that if all elements of the input vector are 0, then an output vector of all nans is returned.

New tests of `xa.normalize()` have been added to ensure it is more robust in the future.

Small fixes (changes of "|" and "&" to "or" and "and" in if statements, mainly)
  • Loading branch information
ks905383 committed Feb 13, 2024
1 parent 2d3b5b5 commit 0075c7e
Show file tree
Hide file tree
Showing 6 changed files with 73 additions and 22 deletions.
18 changes: 18 additions & 0 deletions tests/test_auxfuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,24 @@ def test_normalize():
# I'm pretty sure I only use this with numpy arrays - but double-check
assert np.allclose(normalize(np.array([1,1])),np.array([0.5,0.5]))

def test_normalize_all0s():
    # An all-zero input has no valid normalization (it would divide by
    # zero), so `normalize()` should return a vector of all nans
    norm_vec = normalize(np.array([0,0]))

    expected = np.array([np.nan,np.nan])
    assert np.allclose(norm_vec,expected,equal_nan=True)

def test_normalize_dropnans():
    # With drop_na=True, nan entries should be excluded from the sum
    # used for normalizing, but kept as nans in the output
    result = normalize(np.array([2,4,np.nan,4]),drop_na=True)

    expected = np.array([0.2,0.4,np.nan,0.4])
    assert np.allclose(result,expected,equal_nan=True)


##### fix_ds() tests #####
def test_fix_ds_null():
Expand Down
20 changes: 20 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,27 @@ def test_process_weights_regrid_weights():
# Check if weights were correctly added to ds
xr.testing.assert_allclose(ds_compare,ds_t)

def test_process_weights_close_weights():
    # Make sure weights whose grid is within `np.allclose()` of, but not
    # identical to, the grid of the input ds are correctly allocated
    # (robustness against floating-point differences in grids).
    ds = xr.Dataset(coords={'lat':(['lat'],np.array([0,1])),
                            'lon':(['lon'],np.array([0,1]))})

    # Perturb the weights grid by fixed offsets well below np.allclose's
    # default tolerance (atol=1e-8). Fixed offsets (instead of unseeded
    # np.random.rand()) keep this test deterministic and reproducible.
    weights = xr.DataArray(data=np.array([[0,1],[2,3]]),
                           dims=['lat','lon'],
                           coords=[np.array([0,1])+np.array([3e-11,7e-11]),
                                   np.array([0,1])+np.array([5e-11,2e-11])])

    ds_t,weights_info = process_weights(ds,weights=weights)

    # The weights should be assigned onto the exact ds grid
    ds_compare = xr.Dataset({'weights':(('lat','lon'),np.array([[0,1],[2,3]]))},
                            coords={'lat':(['lat'],np.array([0,1])),
                                    'lon':(['lon'],np.array([0,1])),
                                    })

    # Check if weights were correctly added to ds
    xr.testing.assert_allclose(ds_compare,ds_t)


##### create_raster_polygons() tests #####
Expand Down
15 changes: 5 additions & 10 deletions tests/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,16 +73,8 @@ def test_pixel_overlaps_export_and_import(ds=ds):

fn = 'wm_export_test'

# Build sample dataset
#ds = xr.Dataset({'test':(['lon','lat'],np.array([[0,1],[2,3]])),
# 'lat_bnds':(['lat','bnds'],np.array([[-0.5,0.5],[0.5,1.5]])),
# 'lon_bnds':(['lon','bnds'],np.array([[-0.5,0.5],[0.5,1.5]]))},
# coords={'lat':(['lat'],np.array([0,1])),
# 'lon':(['lon'],np.array([0,1])),
# 'bnds':(['bnds'],np.array([0,1]))})

# Add a simple weights grid
weights = xr.DataArray(data=np.array([[0.,1.],[2.,3.]]).astype(object),
weights = xr.DataArray(data=np.array([[0.,1.],[2.,3.]]),
dims=['lat','lon'],
coords=[ds.lat,ds.lon])

Expand Down Expand Up @@ -124,7 +116,10 @@ def test_pixel_overlaps_export_and_import(ds=ds):
if (type(wm_out.weights) is str) and (wm_out.weights=='nowghts'):
np.testing.assert_string_equal(wm_in.weights,wm_out.weights)
else:
pd.testing.assert_series_equal(wm_in.weights,wm_out.weights)
# `read_wm()` reads in weights as objects (see notes in the relevant
# section of `read_wm()`); this shouldn't have too big
# of a consequence, but it does make this test more complicated
pd.testing.assert_series_equal(wm_in.weights,wm_out.weights.astype(object))

##### clean
shutil.rmtree(fn)
Expand Down
21 changes: 14 additions & 7 deletions xagg/auxfuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
def normalize(a,drop_na = False):
""" Normalizes the vector `a`
The vector `a` is divided by its sum.
The vector `a` is divided by its sum. If all non-`np.nan` elements of `a` are 0,
then all `np.nan`s are returned.
Parameters
---------------
Expand All @@ -32,14 +33,20 @@ def normalize(a,drop_na = False):
"""

if (drop_na) & (np.any(np.isnan(a))):
if (drop_na) and (np.any(np.isnan(a))):
a2 = a[~np.isnan(a)]
a2 = a2/a2.sum()
a[~np.isnan(a)] = a2
if np.all(a2.sum()==0):
# Return nans if the vector is only 0s
# (/ 0 error)
return a*np.nan

return a
else:
a2 = a2/a2.sum()
a[~np.isnan(a)] = a2

return a

elif (np.all(~np.isnan(a))) & (a.sum()>0):
elif (np.all(~np.isnan(a))) and (not np.all(a.sum()==0)):
return a/a.sum()
else:
return a*np.nan
Expand Down Expand Up @@ -241,7 +248,7 @@ def get_bnds(ds,wrap_around_thresh='dynamic',
# honestly, it *may* already work by just changing edges['lon']
# to [0,360], but it's not tested yet.

if ('lat' not in ds) | ('lon' not in ds):
if ('lat' not in ds) or ('lon' not in ds):
raise KeyError('"lat"/"lon" not found in [ds]. Make sure the '+
'geographic dimensions follow this naming convention (e.g., run `xa.fix_ds(ds)` before inputting.')

Expand Down
19 changes: 15 additions & 4 deletions xagg/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,8 @@ def process_weights(ds,weights=None,target='ds',silent=False):

# Regrid, if necessary (do nothing if the grids match up to within
# floating-point precision)
if ((not ((ds.sizes['lat'] == weights.sizes['lat']) & (ds.sizes['lon'] == weights.sizes['lon']))) or
(not (np.allclose(ds.lat,weights.lat) & np.allclose(ds.lon,weights.lon)))):
if ((not ((ds.sizes['lat'] == weights.sizes['lat']) and (ds.sizes['lon'] == weights.sizes['lon']))) or
(not (np.allclose(ds.lat,weights.lat) and np.allclose(ds.lon,weights.lon)))):
# Import xesmf here to allow the code to work without it (it
# often has dependency issues and isn't necessary for many
# features of xagg)
Expand All @@ -176,9 +176,20 @@ def process_weights(ds,weights=None,target='ds',silent=False):

else:
raise KeyError(target+' is not a supported target for regridding. Choose "weights" or "ds".')

else:
# Make sure the values are actually identical, not just "close",
# otherwise assigning may not work below
weights['lat'] = ds['lat'].values
weights['lon'] = ds['lon'].values

# Add weights to ds
ds['weights'] = weights

# Add warnings
if np.isnan(ds['weights']).all():
warnings.warn('All inputted `weights` are np.nan after regridding.')
if (ds['weights'] == 0).all():
warnings.warn('All inputted `weights` are 0 after regridding.')

# Return
return ds,weights_info
Expand Down Expand Up @@ -257,7 +268,7 @@ def create_raster_polygons(ds,
"""

# Standardize inputs
# Standardize inputs (including lat/lon order)
ds = fix_ds(ds)
ds = get_bnds(ds)
#breakpoint()
Expand Down
2 changes: 1 addition & 1 deletion xagg/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import warnings
import shutil

from . auxfuncs import (normalize,fix_ds,get_bnds,subset_find)
from . auxfuncs import (fix_ds,get_bnds,subset_find)


def export_weightmap(wm_obj,fn,overwrite=False):
Expand Down

1 comment on commit 0075c7e

@ks905383
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Closes #57

Please sign in to comment.