Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor seafloor age gridding workflow #271

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
10 changes: 0 additions & 10 deletions gplately/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,27 +232,20 @@
pygplates,
reconstruction,
)
from .data import DataCollection
from .download import DataServer
from .grids import Raster
from .oceans import SeafloorGrid
from .plot import PlotTopologies
from .reconstruction import (
PlateReconstruction,
Points,
_ContinentCollision,
_DefaultCollision,
_ReconstructByTopologies,
)
from .tools import EARTH_RADIUS
from .utils import io_utils
from .utils.io_utils import get_geometries, get_valid_geometries

__pdoc__ = {
"data": False,
"_DefaultCollision": False,
"_ContinentCollision": False,
"_ReconstructByTopologies": False,
"examples": False,
"notebooks": False,
"commands": False,
Expand Down Expand Up @@ -283,9 +276,6 @@
"Points",
"Raster",
"SeafloorGrid",
"_ContinentCollision",
"_DefaultCollision",
"_ReconstructByTopologies",
# Functions
"get_geometries",
"get_valid_geometries",
Expand Down
228 changes: 136 additions & 92 deletions gplately/grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def realign_grid(array, lons, lats):
return array, lons, lats


def read_netcdf_grid(filename, return_grids=False, realign=False, resample=None):
def read_netcdf_grid(filename, return_grids=False, realign=False, resample=None, resize=None):
"""Read a `netCDF` (.nc) grid from a given `filename` and return its data as a
`MaskedArray`.

Expand Down Expand Up @@ -174,6 +174,10 @@ def read_netcdf_grid(filename, return_grids=False, realign=False, resample=None)
If passed as `resample = (spacingX, spacingY)`, the given `netCDF` grid is resampled
with these x and y resolutions.

resize : tuple, optional, default=None
If passed as `resample = (resX, resY)`, the given `netCDF` grid is resized
to the number of columns (resX) and rows (resY).

Returns
-------
grid_z : MaskedArray
Expand Down Expand Up @@ -237,6 +241,18 @@ def find_label(keys, labels):
cdf_lon = cdf[key_lon][:]
cdf_lat = cdf[key_lat][:]

# fill missing values
if hasattr(cdf[key_z], 'missing_value') and np.issubdtype(cdf_grid.dtype, np.floating):
fill_value = cdf[key_z].missing_value
cdf_grid[np.isclose(cdf_grid, fill_value, rtol=0.1)] = np.nan

# convert to boolean array
if np.issubdtype(cdf_grid.dtype, np.integer):
unique_grid = np.unique(cdf_grid)
if len(unique_grid) == 2:
if (unique_grid == [0,1]).all():
cdf_grid = cdf_grid.astype(bool)

if realign:
# realign longitudes to -180/180 dateline
cdf_grid_z, cdf_lon, cdf_lat = realign_grid(cdf_grid, cdf_lon, cdf_lat)
Expand All @@ -246,25 +262,58 @@ def find_label(keys, labels):
# resample
if resample is not None:
spacingX, spacingY = resample
lon_grid = np.arange(cdf_lon.min(), cdf_lon.max() + spacingX, spacingX)
lat_grid = np.arange(cdf_lat.min(), cdf_lat.max() + spacingY, spacingY)
lonq, latq = np.meshgrid(lon_grid, lat_grid)
original_extent = (
cdf_lon[0],
cdf_lon[-1],
cdf_lat[0],
cdf_lat[-1],
)
cdf_grid_z = sample_grid(
lonq,
latq,
cdf_grid_z,
method="nearest",
extent=original_extent,
return_indices=False,
)
cdf_lon = lon_grid
cdf_lat = lat_grid

# don't resample if already the same resolution
dX = np.diff(cdf_lon).mean()
dY = np.diff(cdf_lat).mean()

if spacingX != dX or spacingY != dY:
lon_grid = np.arange(cdf_lon.min(), cdf_lon.max() + spacingX, spacingX)
lat_grid = np.arange(cdf_lat.min(), cdf_lat.max() + spacingY, spacingY)
lonq, latq = np.meshgrid(lon_grid, lat_grid)
original_extent = (
cdf_lon[0],
cdf_lon[-1],
cdf_lat[0],
cdf_lat[-1],
)
cdf_grid_z = sample_grid(
lonq,
latq,
cdf_grid_z,
method="nearest",
extent=original_extent,
return_indices=False,
)
cdf_lon = lon_grid
cdf_lat = lat_grid

# resize
if resize is not None:
resX, resY = resize

# don't resize if already the same shape
if resX != cdf_grid_z.shape[1] or resY != cdf_grid_z.shape[0]:
original_extent = (
cdf_lon[0],
cdf_lon[-1],
cdf_lat[0],
cdf_lat[-1],
)
lon_grid = np.linspace(original_extent[0], original_extent[1], resX)
lat_grid = np.linspace(original_extent[2], original_extent[3], resY)
lonq, latq = np.meshgrid(lon_grid, lat_grid)

cdf_grid_z = sample_grid(
lonq,
latq,
cdf_grid_z,
method="nearest",
extent=original_extent,
return_indices=False,
)
cdf_lon = lon_grid
cdf_lat = lat_grid

# Fix grids with 9e36 as the fill value for nan.
# cdf_grid_z.fill_value = float('nan')
Expand All @@ -274,61 +323,9 @@ def find_label(keys, labels):
return cdf_grid_z, cdf_lon, cdf_lat
else:
return cdf_grid_z


def write_netcdf(filename, lons, lats, data):
"""Write geospatial data to a netCDF4 grid with a specified `filename`.
The latitude, longitude, and data variabels must be of the same size.

Parameters
----------
filename : str
The full path (including a filename and the ".nc" extension) to save the created netCDF4 file.

lons : 1D array
List of longitudinal coordinates to be written into a netCDF4 (.nc) file.

lats : 1D array
List of latitudinal coordinates to be written into a netCDF4 (.nc) file.

data : 1D array
List of data values at lon / lat coordinates to be written into a netCDF4 (.nc) file.
"""
import netCDF4

lons = np.asarray(lons)
lats = np.asarray(lats)
data = np.asarray(data)

with netCDF4.Dataset(filename, "w", driver=None) as cdf:
cdf.title = "Grid produced by gplately"
cdf.createDimension("lon", lons.size)
cdf_lon = cdf.createVariable("lon", lons.dtype, ("lon",), zlib=True)
cdf_lat = cdf.createVariable("lat", lats.dtype, ("lon",), zlib=True)
cdf_lon[:] = lons
cdf_lat[:] = lats

# Units for Geographic Grid type
cdf_lon.units = "degrees_east"
cdf_lon.standard_name = "lon"
cdf_lon.actual_range = [np.min(lons), np.max(lons)]
cdf_lat.units = "degrees_north"
cdf_lat.standard_name = "lat"
cdf_lat.actual_range = [np.min(lats), np.max(lats)]

cdf_data = cdf.createVariable("z", data.dtype, ("lon",), zlib=True)
# netCDF4 uses the missing_value attribute as the default _FillValue
# without this, _FillValue defaults to 9.969209968386869e+36
cdf_data.missing_value = np.nan
cdf_data.standard_name = "z"
# Ensure pygmt registers min and max z values properly
cdf_data.actual_range = [np.nanmin(data), np.nanmax(data)]

cdf_data[:] = data


def write_netcdf_grid(filename, grid, extent=[-180, 180, -90, 90]):
"""Write geological data contained in a `grid` to a netCDF4 grid with a specified `filename`.

def write_netcdf_grid(filename, grid, extent="global", significant_digits=None, fill_value=np.nan, **kwargs):
""" Write geological data contained in a `grid` to a netCDF4 grid with a specified `filename`.

Notes
-----
Expand All @@ -352,28 +349,36 @@ def write_netcdf_grid(filename, grid, extent=[-180, 180, -90, 90]):
An ndarray grid containing data to be written into a `netCDF` (.nc) file. Note: Rows correspond to
the data's latitudes, while the columns correspond to the data's longitudes.

extent : 1D numpy array, default=[-180,180,-90,90]
Four elements that specify the [min lon, max lon, min lat, max lat] to constrain the lat and lon
variables of the netCDF grid to. If no extents are supplied, full global extent `[-180, 180, -90, 90]`
is assumed.
extent : list, default=[-180,180,-90,90]
Four elements that specify the [min lon, max lon, min lat, max lat] to constrain the lat and lon
variables of the netCDF grid to. If no extents are supplied, full global extent `[-180, 180, -90, 90]`
is assumed.

Returns
-------
A netCDF grid will be saved to the path specified in `filename`.
"""
import netCDF4
from gplately import __version__ as _version

if extent == 'global':
extent = [-180, 180, -90, 90]
else:
assert len(extent) == 4, "specify the [min lon, max lon, min lat, max lat]"

nrows, ncols = np.shape(grid)

lon_grid = np.linspace(extent[0], extent[1], ncols)
lat_grid = np.linspace(extent[2], extent[3], nrows)

with netCDF4.Dataset(filename, "w", driver=None) as cdf:
cdf.title = "Grid produced by gplately"
cdf.createDimension("lon", lon_grid.size)
cdf.createDimension("lat", lat_grid.size)
cdf_lon = cdf.createVariable("lon", lon_grid.dtype, ("lon",), zlib=True)
cdf_lat = cdf.createVariable("lat", lat_grid.dtype, ("lat",), zlib=True)
data_kwds = {'compression': 'zlib', 'complevel': 9}

with netCDF4.Dataset(filename, 'w', driver=None) as cdf:
cdf.title = "Grid produced by gplately " + str(_version)
cdf.createDimension('lon', lon_grid.size)
cdf.createDimension('lat', lat_grid.size)
cdf_lon = cdf.createVariable('lon', lon_grid.dtype, ('lon',), **data_kwds)
cdf_lat = cdf.createVariable('lat', lat_grid.dtype, ('lat',), **data_kwds)
cdf_lon[:] = lon_grid
cdf_lat[:] = lat_grid

Expand All @@ -385,15 +390,45 @@ def write_netcdf_grid(filename, grid, extent=[-180, 180, -90, 90]):
cdf_lat.standard_name = "lat"
cdf_lat.actual_range = [lat_grid[0], lat_grid[-1]]

cdf_data = cdf.createVariable("z", grid.dtype, ("lat", "lon"), zlib=True)
# create container variable for CRS: lon/lat WGS84 datum
crso = cdf.createVariable('crs','i4')
crso.long_name = 'Lon/Lat Coords in WGS84'
crso.grid_mapping_name='latitude_longitude'
crso.longitude_of_prime_meridian = 0.0
crso.semi_major_axis = 6378137.0
crso.inverse_flattening = 298.257223563
crso.spatial_ref = """GEOGCS["WGS 84",DATUM["WGS_1984",SPHEROID["WGS 84",6378137,298.257223563,AUTHORITY["EPSG","7030"]],AUTHORITY["EPSG","6326"]],PRIMEM["Greenwich",0,AUTHORITY["EPSG","8901"]],UNIT["degree",0.01745329251994328,AUTHORITY["EPSG","9122"]],AUTHORITY["EPSG","4326"]]"""

# add more keyword arguments for quantizing data
if significant_digits:
# significant_digits needs to be >= 2 so that NaNs are preserved
data_kwds['significant_digits'] = max(2, int(significant_digits))
data_kwds['quantize_mode'] = 'GranularBitRound'

# boolean arrays need to be converted to integers
# no such thing as a mask on a boolean array
if grid.dtype is np.dtype(bool):
grid = grid.astype('i1')
fill_value = None

cdf_data = cdf.createVariable('z', grid.dtype, ('lat','lon'), **data_kwds)

# netCDF4 uses the missing_value attribute as the default _FillValue
# without this, _FillValue defaults to 9.969209968386869e+36
cdf_data.missing_value = np.nan
cdf_data.standard_name = "z"
# Ensure pygmt registers min and max z values properly
if fill_value is not None:
cdf_data.missing_value = fill_value

cdf_data.standard_name = 'z'

cdf_data.add_offset = 0.0
cdf_data.grid_mapping = 'crs'
cdf_data.set_auto_maskandscale(False)

# ensure min and max z values are properly registered
cdf_data.actual_range = [np.nanmin(grid), np.nanmax(grid)]

cdf_data[:, :] = grid
# write data
cdf_data[:,:] = grid


class RegularGridInterpolator(_RGI):
Expand Down Expand Up @@ -1511,6 +1546,7 @@ def __init__(
extent="global",
realign=False,
resample=None,
resize=None,
time=0.0,
origin=None,
**kwargs,
Expand Down Expand Up @@ -1540,6 +1576,10 @@ def __init__(
Optionally resample grid, pass spacing in X and Y direction as a
2-tuple e.g. resample=(spacingX, spacingY).

resize : 2-tuple, optional
Optionally resample grid to X-columns, Y-rows as a
2-tuple e.g. resample=(resX, resY).

time : float, default: 0.0
The time step represented by the raster data. Used for raster
reconstruction.
Expand Down Expand Up @@ -1610,6 +1650,7 @@ def __init__(
return_grids=True,
realign=realign,
resample=resample,
resize=resize,
)
self._lons = lons
self._lats = lats
Expand All @@ -1631,6 +1672,9 @@ def __init__(
if (not isinstance(data, str)) and (resample is not None):
self.resample(*resample, inplace=True)

if (not isinstance(data, str)) and (resize is not None):
self.resize(*resize, inplace=True)

@property
def time(self):
"""The time step represented by the raster data."""
Expand Down Expand Up @@ -1973,10 +2017,10 @@ def fill_NaNs(self, inplace=False, return_array=False):
else:
return Raster(data, self.plate_reconstruction, self.extent, self.time)

def save_to_netcdf4(self, filename):
def save_to_netcdf4(self, filename, significant_digits=None, fill_value=np.nan):
"""Saves the grid attributed to the `Raster` object to the given `filename` (including
the ".nc" extension) in netCDF4 format."""
write_netcdf_grid(str(filename), self.data, self.extent)
write_netcdf_grid(str(filename), self.data, self.extent, significant_digits, fill_value)

def reconstruct(
self,
Expand Down
Loading
Loading