Skip to content

Commit

Permalink
make tests more robust to dot product implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
ks905383 committed Feb 13, 2024
1 parent 0075c7e commit 2a5d31c
Showing 1 changed file with 24 additions and 12 deletions.
36 changes: 24 additions & 12 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,14 +434,17 @@ def test_aggregate_with_weights(ds=ds):


# Get pixel overlaps
wm = get_pixel_overlaps(gdf,pix_agg)
wm_for = get_pixel_overlaps(gdf,pix_agg,impl='for_loop')
wm_dot = get_pixel_overlaps(gdf,pix_agg,impl='dot_product')

# Get aggregate
agg = aggregate(ds,wm)
agg_for = aggregate(ds,wm_for,impl='for_loop')
agg_dot = aggregate(ds,wm_dot,impl='dot_product')

# Since the "test" for the input ds has [0,2] for the two
# equatorial pixels, the average should just be 1.0
assert np.allclose([v for v in agg.agg.test.values],1.0)
assert np.allclose([v for v in agg_for.agg.test.values],1.0)
assert np.allclose([v for v in agg_dot.agg.test.values],1.0)



Expand All @@ -467,13 +470,16 @@ def test_aggregate_with_mismatched_grid():


# Get pixel overlaps
wm = get_pixel_overlaps(gdf,pix_agg)
wm_for = get_pixel_overlaps(gdf,pix_agg,impl='for_loop')
wm_dot = get_pixel_overlaps(gdf,pix_agg,impl='dot_product')

# Get aggregate
agg = aggregate(ds,wm)
agg_for = aggregate(ds,wm_for,impl='for_loop')
agg_dot = aggregate(ds,wm_dot,impl='dot_product')

# On change in rtol, see note in test_aggregate_basic
assert np.allclose([v for v in agg.agg.test.values],1.4999,rtol=1e-4)
assert np.allclose([v for v in agg_for.agg.test.values],1.4999,rtol=1e-4)
assert np.allclose([v for v in agg_dot.agg.test.values],1.4999,rtol=1e-4)


# Should probably test multiple polygons just to be sure...
Expand Down Expand Up @@ -501,16 +507,19 @@ def test_aggregate_with_all_nans():


# Get pixel overlaps
wm = get_pixel_overlaps(gdf,pix_agg)
wm_for = get_pixel_overlaps(gdf,pix_agg,impl='for_loop')
wm_dot = get_pixel_overlaps(gdf,pix_agg,impl='dot_product')

# Get aggregate
agg = aggregate(ds,wm)
agg_for = aggregate(ds,wm_for,impl='for_loop')
agg_dot = aggregate(ds,wm_dot,impl='dot_product')

# Should only return nan
# (this is not a great assert - but agg.agg.test[0] comes out as [array(nan)],
# which... I'm not entirely sure how to reproduce. It quaks like a single nan,
# but it's unclear to me how to get it to work)
assert np.all([np.isnan(k) for k in agg.agg.test])
assert np.all([np.isnan(k) for k in agg_for.agg.test])
assert np.all([np.isnan(k) for k in agg_dot.agg.test])

def test_aggregate_with_some_nans():
ds = xr.Dataset({'test':(['lon','lat'],np.array([[np.nan,1],[2,np.nan]])),
Expand All @@ -530,13 +539,16 @@ def test_aggregate_with_some_nans():


# Get pixel overlaps
wm = get_pixel_overlaps(gdf,pix_agg)
wm_for = get_pixel_overlaps(gdf,pix_agg,impl='for_loop')
wm_dot = get_pixel_overlaps(gdf,pix_agg,impl='dot_product')

# Get aggregate
agg = aggregate(ds,wm)
agg_for = aggregate(ds,wm_for,impl='for_loop')
agg_dot = aggregate(ds,wm_dot,impl='dot_product')

# Should be 1.5; with one pixel valued 1, one pixel valued 2.
assert np.allclose([agg.agg.test[0]],1.5,rtol=1e-4)
assert np.allclose([agg_for.agg.test[0]],1.5,rtol=1e-4)
assert np.allclose([agg_dot.agg.test[0]],1.5,rtol=1e-4)



Expand Down

0 comments on commit 2a5d31c

Please sign in to comment.