Skip to content

Commit

Permalink
Fixing xi sampling code, and making unit test more subtle
Browse files Browse the repository at this point in the history
  • Loading branch information
erikvansebille committed Jul 25, 2024
1 parent acba3fd commit 47f50a0
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 17 deletions.
3 changes: 2 additions & 1 deletion parcels/compilation/codegenerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,7 +657,8 @@ def visit_Subscript(self, node):
if isinstance(node.value, FieldNode) or isinstance(node.value, VectorFieldNode):
node.ccode = node.value.__getitem__(node.slice.ccode).ccode
elif isinstance(node.value, ParticleXiYiZiTiAttributeNode):
node.ccode = f"{node.value.obj}->{node.value.attr}[pnum, {node.slice.ccode}]"
ngrid = str(self.fieldset.gridset.size if self.fieldset is not None else 1)
node.ccode = f"{node.value.obj}->{node.value.attr}[pnum*{ngrid}+{node.slice.ccode}]"
elif isinstance(node.value, IntrinsicNode):
raise NotImplementedError(f"Subscript not implemented for object type {type(node.value).__name__}")
else:
Expand Down
35 changes: 19 additions & 16 deletions tests/test_particlefile.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def Update_lon(particle, fieldset, time):
def test_write_xiyi(fieldset, mode, tmpdir):
outfilepath = tmpdir.join("pfile_xiyi.zarr")
fieldset.U.data[:] = 1 # set a non-zero zonal velocity
fieldset.add_field(Field(name='P', data=np.zeros((2, 20)), lon=np.linspace(0, 1, 20), lat=[0, 2]))
fieldset.add_field(Field(name='P', data=np.zeros((3, 20)), lon=np.linspace(0, 1, 20), lat=[-2, 0, 2]))
dt = 3600

XiYiParticle = ptype[mode].add_variables([
Expand All @@ -306,24 +306,27 @@ def SampleP(particle, fieldset, time):
if time > 5*3600:
tmp = fieldset.P[particle] # noqa

pset = ParticleSet(fieldset, pclass=XiYiParticle, lon=[0], lat=[0.2], lonlatdepth_dtype=np.float64)
pset = ParticleSet(fieldset, pclass=XiYiParticle, lon=[0, 0.2], lat=[0.2, 1], lonlatdepth_dtype=np.float64)
pfile = pset.ParticleFile(name=outfilepath, outputdt=dt)
pset.execute([Get_XiYi, SampleP, AdvectionRK4], endtime=10*dt, dt=dt, output_file=pfile)
pset.execute([SampleP, Get_XiYi, AdvectionRK4], endtime=10*dt, dt=dt, output_file=pfile)

ds = xr.open_zarr(outfilepath)
pxi0 = ds['pxi0'][:].values[0].astype(np.int32)
pxi1 = ds['pxi1'][:].values[0].astype(np.int32)
lons = ds['lon'][:].values[0]
pyi = ds['pyi'][:].values[0].astype(np.int32)
lats = ds['lat'][:].values[0]

assert (pxi0[0] == 0) and (pxi0[-1] == 11) # check that particle has moved
assert np.all(pxi1[:7] == 0) # check that particle has not been sampled on grid 1 until time 6
assert np.all(pxi1[7:] > 0) # check that particle has not been sampled on grid 1 after time 6
for xi, lon in zip(pxi0[1:], lons[1:]):
assert fieldset.U.grid.lon[xi] <= lon < fieldset.U.grid.lon[xi+1]
for yi, lat in zip(pyi[1:], lats[1:]):
assert fieldset.U.grid.lat[yi] <= lat < fieldset.U.grid.lat[yi+1]
pxi0 = ds['pxi0'][:].values.astype(np.int32)
pxi1 = ds['pxi1'][:].values.astype(np.int32)
lons = ds['lon'][:].values
pyi = ds['pyi'][:].values.astype(np.int32)
lats = ds['lat'][:].values

for p in range(pyi.shape[0]):
assert (pxi0[p, 0] == 0) and (pxi0[p, -1] == pset[p].pxi0) # check that particle has moved
assert np.all(pxi1[p, :6] == 0) # check that particle has not been sampled on grid 1 until time 6
assert np.all(pxi1[p, 6:] > 0) # check that particle has not been sampled on grid 1 after time 6
for xi, lon in zip(pxi0[p, 1:], lons[p, 1:]):
assert fieldset.U.grid.lon[xi] <= lon < fieldset.U.grid.lon[xi+1]
for xi, lon in zip(pxi1[p, 6:], lons[p, 6:]):
assert fieldset.P.grid.lon[xi] <= lon < fieldset.P.grid.lon[xi+1]
for yi, lat in zip(pyi[p, 1:], lats[p, 1:]):
assert fieldset.U.grid.lat[yi] <= lat < fieldset.U.grid.lat[yi+1]
ds.close()


Expand Down

0 comments on commit 47f50a0

Please sign in to comment.