Skip to content

Commit

Permalink
Fix bad _bnds existence check and remove .keys() references
Browse files Browse the repository at this point in the history
Fixes issue in `get_bnds()` where the existence of `lat_bounds` instead of `lat_bnds` (the `xa.fix_ds()` convention) was checked for
  • Loading branch information
ks905383 authored Feb 13, 2024
1 parent 69516b3 commit c70c2e1
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions xagg/aux.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,15 +111,15 @@ def fix_ds(ds,var_cipher = {'latitude':{'latitude':'lat','longitude':'lon'},

# List of variables that represent bounds
if type(ds) == xr.core.dataset.Dataset:
bnd_vars = [k for k in list(ds.keys()) if 'bnds' in k]
bnd_vars = [k for k in list(ds) if 'bnds' in k]
elif type(ds) == xr.core.dataarray.DataArray:
bnd_vars = []
else:
raise TypeError('[ds] needs to be an xarray structure (Dataset or DataArray).')

# Fix lat/lon variable names (sizes instead of dims to be compatible with both ds, da...)
if 'lat' not in ds.sizes.keys():
test_dims = [k for k in var_cipher.keys() if k in ds.sizes.keys()]
if 'lat' not in ds.sizes:
test_dims = [k for k in var_cipher if k in ds.sizes]
if len(test_dims) == 0:
raise NameError('No valid lat/lon variables found in the dataset.')
else:
Expand Down Expand Up @@ -156,7 +156,7 @@ def fix_ds(ds,var_cipher = {'latitude':{'latitude':'lat','longitude':'lon'},
if (type(ds) == xr.core.dataset.Dataset):
# Three if statements because of what I believe to be a jupyter error
# (where all three statements are evaluated instead of one at a time)
if ('lon_bnds' in ds.keys()):
if ('lon_bnds' in ds):
if (ds.lon_bnds.max()>180):
ds['lon_bnds'] = (ds['lon_bnds'] + 180) % 360 - 180

Expand Down Expand Up @@ -241,11 +241,13 @@ def get_bnds(ds,wrap_around_thresh='dynamic',
# honestly, it *may* already work by just changing edges['lon']
# to [0,360], but it's not tested yet.

if ('lat' not in ds.keys()) | ('lon' not in ds.keys()):
if ('lat' not in ds) | ('lon' not in ds):
raise KeyError('"lat"/"lon" not found in [ds]. Make sure the '+
'geographic dimensions follow this naming convention (e.g., run `xa.fix_ds(ds)` before inputting.')

if 'lat_bounds' in ds.keys():
if ('lat_bnds' in ds) and (`lon_bnds` in ds):
# `xa.fix_ds()` should rename bounds to `lat/lon_bnds`
# If bounds present, do nothing
return ds
else:
if not silent:
Expand Down Expand Up @@ -388,12 +390,12 @@ def subset_find(ds0,ds1):
"""

if 'loc' not in ds0.sizes.keys():
if 'loc' not in ds0.sizes:
ds0 = ds0.stack(loc = ('lat','lon'))
was_stacked = True
else:
was_stacked = False
#if 'loc' not in ds1.sizes.keys():
#if 'loc' not in ds1.sizes:
# ds1 = ds1.stack(loc = ('lat','lon'))

# Need a test to make sure the grid is the same. So maybe the gdf_out class
Expand Down

0 comments on commit c70c2e1

Please sign in to comment.