Merge pull request #1727 from OceanParcels/v/api
API changes: `particlefile.py` and other touchups
VeckoTheGecko authored Oct 22, 2024
2 parents 9907c84 + f03f80e commit aa716e4
Showing 9 changed files with 186 additions and 72 deletions.
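The central pattern in this commit is renaming public attributes and methods to underscore-prefixed private names, keeping the old names available as read-only properties or shims wrapped in deprecation decorators imported from `parcels.tools._helpers`. That module's implementation is not part of this diff; the sketch below is only an illustration, under the assumption that the decorator emits a `DeprecationWarning` and delegates to the wrapped member:

```python
import functools
import warnings


def deprecated_made_private(func):
    # Illustrative sketch only: the real decorator lives in
    # parcels/tools/_helpers.py and is not shown in this diff.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        warnings.warn(
            f"`{func.__name__}` has been made private and may be removed in a "
            "future release. Use the documented public API instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return func(*args, **kwargs)

    return wrapper
```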
13 changes: 0 additions & 13 deletions parcels/grid.py
@@ -279,19 +279,6 @@ class CStructuredGrid(Structure):
             )
         return self._cstruct
 
-    def lon_grid_to_target(self):
-        if self.lon_remapping:
-            self._lon = self.lon_remapping.to_target(self.lon)
-
-    def lon_grid_to_source(self):
-        if self.lon_remapping:
-            self._lon = self.lon_remapping.to_source(self.lon)
-
-    def lon_particle_to_target(self, lon):
-        if self.lon_remapping:
-            return self.lon_remapping.particle_to_target(lon)
-        return lon
-
     @deprecated_made_private # TODO: Remove 6 months after v3.1.0
     def check_zonal_periodic(self, *args, **kwargs):
         return self._check_zonal_periodic(*args, **kwargs)
2 changes: 1 addition & 1 deletion parcels/kernel.py
@@ -367,7 +367,7 @@ def check_fieldsets_in_kernels(self, pyfunc):
                 )
         elif pyfunc is AdvectionAnalytical:
             if self.fieldset.particlefile is not None:
-                self.fieldset.particlefile.analytical = True
+                self.fieldset.particlefile._is_analytical = True
             if self._ptype.uses_jit:
                 raise NotImplementedError("Analytical Advection only works in Scipy mode")
             if self._fieldset.U.interp_method != "cgrid_velocity":
6 changes: 5 additions & 1 deletion parcels/particle.py
@@ -29,11 +29,15 @@ class Variable:
     """
 
     def __init__(self, name, dtype=np.float32, initial=0, to_write: bool | Literal["once"] = True):
-        self.name = name
+        self._name = name
         self.dtype = dtype
         self.initial = initial
         self.to_write = to_write
 
+    @property
+    def name(self):
+        return self._name
+
     def __get__(self, instance, cls):
         if instance is None:
             return self
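With `name` now behind a getter-only property, renaming a `Variable` after construction raises an `AttributeError` instead of silently rebinding the attribute. A quick before/after illustration:

```python
import numpy as np
from parcels import Variable

temp = Variable("temperature", dtype=np.float32, initial=0.0)
print(temp.name)  # "temperature" -- reads go through the new property

try:
    temp.name = "salinity"  # rebinding worked before this commit; no setter now
except AttributeError:
    print("Variable.name is read-only")
```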
8 changes: 6 additions & 2 deletions parcels/particledata.py
@@ -5,6 +5,7 @@
 import numpy as np
 
 from parcels._compat import MPI, KMeans
+from parcels.tools._helpers import deprecated
 from parcels.tools.statuscodes import StatusCode
 
 
@@ -228,12 +229,15 @@ def __len__(self):
         """Return the length, in terms of 'number of elements, of a ParticleData instance."""
         return self._ncount
 
+    @deprecated(
+        "Use iter(...) instead, or just use the object in an iterator context (e.g. for p in particledata: ...)."
+    ) # TODO: Remove 6 months after v3.1.0 (or 9 months; doesn't contribute to code debt)
     def iterator(self):
-        return ParticleDataIterator(self)
+        return iter(self)
 
     def __iter__(self):
         """Return an Iterator that allows for forward iteration over the elements in the ParticleData (e.g. `for p in pset:`)."""
-        return self.iterator()
+        return ParticleDataIterator(self)
 
     def __getitem__(self, index):
         """Get a particle object from the ParticleData instance based on its index."""
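The deprecated `iterator()` and the iteration protocol are now wired the right way around: `__iter__` constructs the `ParticleDataIterator`, and `iterator()` merely delegates to `iter(...)` while warning. In user code the idiomatic form is direct iteration; a hedged sketch, assuming an existing ParticleSet `pset`:

```python
# Deprecated: still works, but now emits a deprecation warning
for p in pset.particledata.iterator():
    print(p.id)

# Preferred: plain iteration, which calls __iter__ directly
for p in pset.particledata:
    print(p.id)
```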
152 changes: 106 additions & 46 deletions parcels/particlefile.py
@@ -10,6 +10,7 @@
 
 import parcels
 from parcels._compat import MPI
+from parcels.tools._helpers import deprecated, deprecated_made_private
 from parcels.tools.warnings import FileWarning
 
 __all__ = ["ParticleFile"]
@@ -46,31 +47,24 @@ class ParticleFile:
     ParticleFile object that can be used to write particle data to file
     """
 
-    outputdt = None
-    particleset = None
-    parcels_mesh = None
-    time_origin = None
-    lonlatdepth_dtype = None
-
     def __init__(self, name, particleset, outputdt=np.inf, chunks=None, create_new_zarrfile=True):
-        self.outputdt = outputdt.total_seconds() if isinstance(outputdt, timedelta) else outputdt
-        self.chunks = chunks
-        self.particleset = particleset
-        self.parcels_mesh = "spherical"
+        self._outputdt = outputdt.total_seconds() if isinstance(outputdt, timedelta) else outputdt
+        self._chunks = chunks
+        self._particleset = particleset
+        self._parcels_mesh = "spherical"
         if self.particleset.fieldset is not None:
-            self.parcels_mesh = self.particleset.fieldset.gridset.grids[0].mesh
-            self.time_origin = self.particleset.time_origin
+            self._parcels_mesh = self.particleset.fieldset.gridset.grids[0].mesh
         self.lonlatdepth_dtype = self.particleset.particledata.lonlatdepth_dtype
-        self.maxids = 0
-        self.pids_written = {}
-        self.create_new_zarrfile = create_new_zarrfile
-        self.vars_to_write = {}
+        self._maxids = 0
+        self._pids_written = {}
+        self._create_new_zarrfile = create_new_zarrfile
+        self._vars_to_write = {}
         for var in self.particleset.particledata.ptype.variables:
             if var.to_write:
                 self.vars_to_write[var.name] = var.dtype
-        self.mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
+        self._mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
         self.particleset.fieldset._particlefile = self
-        self.analytical = False # Flag to indicate if ParticleFile is used for analytical trajectories
+        self._is_analytical = False # Flag to indicate if ParticleFile is used for analytical trajectories
 
         # Reset obs_written of each particle, in case new ParticleFile created for a ParticleSet
         particleset.particledata.setallvardata("obs_written", 0)
@@ -80,11 +74,11 @@ def __init__(self, name, particleset, outputdt=np.inf, chunks=None, create_new_z
             "Conventions": "CF-1.6/CF-1.7",
             "ncei_template_version": "NCEI_NetCDF_Trajectory_Template_v2.0",
             "parcels_version": parcels.__version__,
-            "parcels_mesh": self.parcels_mesh,
+            "parcels_mesh": self._parcels_mesh,
         }
 
         # Create dictionary to translate datatypes and fill_values
-        self.fill_value_map = {
+        self._fill_value_map = {
             np.float16: np.nan,
             np.float32: np.nan,
             np.float64: np.nan,
@@ -103,23 +97,82 @@ def __init__(self, name, particleset, outputdt=np.inf, chunks=None, create_new_z
             # But we need to handle incompatibility with MPI mode for now:
             if MPI and MPI.COMM_WORLD.Get_size() > 1:
                 raise ValueError("Currently, MPI mode is not compatible with directly passing a Zarr store.")
-            self.fname = name
+            fname = name
         else:
             extension = os.path.splitext(str(name))[1]
             if extension in [".nc", ".nc4"]:
                 raise RuntimeError(
                     "Output in NetCDF is not supported anymore. Use .zarr extension for ParticleFile name."
                 )
             if MPI and MPI.COMM_WORLD.Get_size() > 1:
-                self.fname = os.path.join(name, f"proc{self.mpi_rank:02d}.zarr")
+                fname = os.path.join(name, f"proc{self._mpi_rank:02d}.zarr")
                 if extension in [".zarr"]:
                     warnings.warn(
-                        f"The ParticleFile name contains .zarr extension, but zarr files will be written per processor in MPI mode at {self.fname}",
+                        f"The ParticleFile name contains .zarr extension, but zarr files will be written per processor in MPI mode at {fname}",
                         FileWarning,
                         stacklevel=2,
                     )
             else:
-                self.fname = name if extension in [".zarr"] else f"{name}.zarr"
+                fname = name if extension in [".zarr"] else f"{name}.zarr"
+        self._fname = fname
+
+    @property
+    def create_new_zarrfile(self):
+        return self._create_new_zarrfile
+
+    @property
+    def outputdt(self):
+        return self._outputdt
+
+    @property
+    def chunks(self):
+        return self._chunks
+
+    @property
+    def particleset(self):
+        return self._particleset
+
+    @property
+    def fname(self):
+        return self._fname
+
+    @property
+    def vars_to_write(self):
+        return self._vars_to_write
+
+    @property
+    def time_origin(self):
+        return self.particleset.time_origin
+
+    @property
+    @deprecated_made_private # TODO: Remove 6 months after v3.1.0
+    def parcels_mesh(self):
+        return self._parcels_mesh
+
+    @property
+    @deprecated_made_private # TODO: Remove 6 months after v3.1.0
+    def maxids(self):
+        return self._maxids
+
+    @property
+    @deprecated_made_private # TODO: Remove 6 months after v3.1.0
+    def pids_written(self):
+        return self._pids_written
+
+    @property
+    @deprecated_made_private # TODO: Remove 6 months after v3.1.0
+    def mpi_rank(self):
+        return self._mpi_rank
+
+    @property
+    @deprecated_made_private # TODO: Remove 6 months after v3.1.0
+    def fill_value_map(self):
+        return self._fill_value_map
+
+    @property
+    @deprecated_made_private # TODO: Remove 6 months after v3.1.0
+    def analytical(self):
+        return self._is_analytical
 
     def _create_variables_attribute_dict(self):
         """Creates the dictionary with variable attributes.
Expand All @@ -133,7 +186,7 @@ def _create_variables_attribute_dict(self):
"trajectory": {
"long_name": "Unique identifier for each particle",
"cf_role": "trajectory_id",
"_FillValue": self.fill_value_map[np.int64],
"_FillValue": self._fill_value_map[np.int64],
},
"time": {"long_name": "", "standard_name": "time", "units": "seconds", "axis": "T"},
"lon": {"long_name": "", "standard_name": "longitude", "units": "degrees_east", "axis": "X"},
@@ -147,14 +200,17 @@
         for vname in self.vars_to_write:
             if vname not in ["time", "lat", "lon", "depth", "id"]:
                 attrs[vname] = {
-                    "_FillValue": self.fill_value_map[self.vars_to_write[vname]],
+                    "_FillValue": self._fill_value_map[self.vars_to_write[vname]],
                     "long_name": "",
                     "standard_name": vname,
                     "units": "unknown",
                 }
 
         return attrs
 
+    @deprecated(
+        "ParticleFile.metadata is a dictionary. Use `ParticleFile.metadata['key'] = ...` or other dictionary methods instead."
+    ) # TODO: Remove 6 months after v3.1.0
     def add_metadata(self, name, message):
         """Add metadata to :class:`parcels.particleset.ParticleSet`.
@@ -175,21 +231,25 @@ def _convert_varout_name(self, var):
         else:
             return var
 
-    def write_once(self, var):
+    @deprecated_made_private # TODO: Remove 6 months after v3.1.0
+    def write_once(self, *args, **kwargs):
+        return self._write_once(*args, **kwargs)
+
+    def _write_once(self, var):
         return self.particleset.particledata.ptype[var].to_write == "once"
 
     def _extend_zarr_dims(self, Z, store, dtype, axis):
         if axis == 1:
-            a = np.full((Z.shape[0], self.chunks[1]), self.fill_value_map[dtype], dtype=dtype)
+            a = np.full((Z.shape[0], self.chunks[1]), self._fill_value_map[dtype], dtype=dtype)
             obs = zarr.group(store=store, overwrite=False)["obs"]
             if len(obs) == Z.shape[1]:
                 obs.append(np.arange(self.chunks[1]) + obs[-1] + 1)
         else:
-            extra_trajs = self.maxids - Z.shape[0]
+            extra_trajs = self._maxids - Z.shape[0]
             if len(Z.shape) == 2:
-                a = np.full((extra_trajs, Z.shape[1]), self.fill_value_map[dtype], dtype=dtype)
+                a = np.full((extra_trajs, Z.shape[1]), self._fill_value_map[dtype], dtype=dtype)
             else:
-                a = np.full((extra_trajs,), self.fill_value_map[dtype], dtype=dtype)
+                a = np.full((extra_trajs,), self._fill_value_map[dtype], dtype=dtype)
         Z.append(a, axis=axis)
         zarr.consolidate_metadata(store)
 
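`_extend_zarr_dims` grows the on-disk arrays by appending blocks of fill values along the trajectory axis (axis 0) or the obs axis (axis 1). A standalone sketch of that padding strategy with the zarr v2 API the file already uses (array names and sizes here are made up):

```python
import numpy as np
import zarr

# Create a small on-disk array with one "obs" column
store = zarr.DirectoryStore("example.zarr")
root = zarr.group(store=store, overwrite=True)
z = root.create_dataset("lat", shape=(2, 1), chunks=(2, 1), dtype="f4", fill_value=np.nan)

# Extend the obs axis (axis=1) by one chunk of NaN fill values,
# mirroring what _extend_zarr_dims does before new observations land
pad = np.full((z.shape[0], 1), np.nan, dtype="f4")
z.append(pad, axis=1)
print(z.shape)  # (2, 2)
```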
@@ -221,11 +281,11 @@ def write(self, pset, time, indices=None):
 
         if len(indices_to_write) > 0:
             pids = pset.particledata.getvardata("id", indices_to_write)
-            to_add = sorted(set(pids) - set(self.pids_written.keys()))
+            to_add = sorted(set(pids) - set(self._pids_written.keys()))
             for i, pid in enumerate(to_add):
-                self.pids_written[pid] = self.maxids + i
-            ids = np.array([self.pids_written[p] for p in pids], dtype=int)
-            self.maxids = len(self.pids_written)
+                self._pids_written[pid] = self._maxids + i
+            ids = np.array([self._pids_written[p] for p in pids], dtype=int)
+            self._maxids = len(self._pids_written)
 
             once_ids = np.where(pset.particledata.getvardata("obs_written", indices_to_write) == 0)[0]
             if len(once_ids) > 0:
@@ -234,7 +294,7 @@
 
             if self.create_new_zarrfile:
                 if self.chunks is None:
-                    self.chunks = (len(ids), 1)
+                    self._chunks = (len(ids), 1)
                 if pset._repeatpclass is not None and self.chunks[0] < 1e4:
                     warnings.warn(
                         f"ParticleFile chunks are set to {self.chunks}, but this may lead to "
@@ -243,37 +303,37 @@
                         FileWarning,
                         stacklevel=2,
                     )
-                if (self.maxids > len(ids)) or (self.maxids > self.chunks[0]):
-                    arrsize = (self.maxids, self.chunks[1])
+                if (self._maxids > len(ids)) or (self._maxids > self.chunks[0]):
+                    arrsize = (self._maxids, self.chunks[1])
                 else:
                     arrsize = (len(ids), self.chunks[1])
                 ds = xr.Dataset(
                     attrs=self.metadata,
                     coords={"trajectory": ("trajectory", pids), "obs": ("obs", np.arange(arrsize[1], dtype=np.int32))},
                 )
                 attrs = self._create_variables_attribute_dict()
-                obs = np.zeros((self.maxids), dtype=np.int32)
+                obs = np.zeros((self._maxids), dtype=np.int32)
                 for var in self.vars_to_write:
                     varout = self._convert_varout_name(var)
                     if varout not in ["trajectory"]: # because 'trajectory' is written as coordinate
-                        if self.write_once(var):
+                        if self._write_once(var):
                             data = np.full(
                                 (arrsize[0],),
-                                self.fill_value_map[self.vars_to_write[var]],
+                                self._fill_value_map[self.vars_to_write[var]],
                                 dtype=self.vars_to_write[var],
                             )
                             data[ids_once] = pset.particledata.getvardata(var, indices_to_write_once)
                             dims = ["trajectory"]
                         else:
                             data = np.full(
-                                arrsize, self.fill_value_map[self.vars_to_write[var]], dtype=self.vars_to_write[var]
+                                arrsize, self._fill_value_map[self.vars_to_write[var]], dtype=self.vars_to_write[var]
                             )
                             data[ids, 0] = pset.particledata.getvardata(var, indices_to_write)
                             dims = ["trajectory", "obs"]
                         ds[varout] = xr.DataArray(data=data, dims=dims, attrs=attrs[varout])
-                        ds[varout].encoding["chunks"] = self.chunks[0] if self.write_once(var) else self.chunks
+                        ds[varout].encoding["chunks"] = self.chunks[0] if self._write_once(var) else self.chunks
                 ds.to_zarr(self.fname, mode="w")
-                self.create_new_zarrfile = False
+                self._create_new_zarrfile = False
             else:
                 # Either use the store that was provided directly or create a DirectoryStore:
                 if issubclass(type(self.fname), zarr.storage.Store):
@@ -284,9 +344,9 @@
                 obs = pset.particledata.getvardata("obs_written", indices_to_write)
                 for var in self.vars_to_write:
                     varout = self._convert_varout_name(var)
-                    if self.maxids > Z[varout].shape[0]:
+                    if self._maxids > Z[varout].shape[0]:
                         self._extend_zarr_dims(Z[varout], store, dtype=self.vars_to_write[var], axis=0)
-                    if self.write_once(var):
+                    if self._write_once(var):
                         if len(once_ids) > 0:
                             Z[varout].vindex[ids_once] = pset.particledata.getvardata(var, indices_to_write_once)
                         else:
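Taken together, the `ParticleFile` changes keep the public read-only surface (`outputdt`, `chunks`, `particleset`, `fname`, `vars_to_write`, `time_origin`, `create_new_zarrfile`) while routing the now-internal members through deprecation shims. A hedged migration sketch, assuming an existing ParticleSet `pset` (file name and metadata value are illustrative):

```python
from datetime import timedelta

pfile = pset.ParticleFile(name="out.zarr", outputdt=timedelta(hours=6))

pfile.outputdt   # still public, but now read-only (no setter)
pfile.fname      # still public, resolved to "out.zarr"

pfile.maxids      # deprecated: warns, returns pfile._maxids
pfile.analytical  # deprecated: warns, returns pfile._is_analytical

# add_metadata() is deprecated in favour of plain dictionary access:
pfile.metadata["institution"] = "Utrecht University"     # preferred
pfile.add_metadata("institution", "Utrecht University")  # warns
```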
(Diff for the remaining four changed files was not loaded.)