Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

merge "Seafloorgrid contmask" branch #103

Merged
merged 2 commits into from
Jul 31, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
240 changes: 174 additions & 66 deletions gplately/oceans.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@
from . import grids
from . import tools


# -------------------------------------------------------------------------
# Auxiliary functions for SeafloorGrid

Expand Down Expand Up @@ -381,6 +382,11 @@ class SeafloorGrid(object):
zval_names : list of str
A list containing string labels for the z values to attribute to points.
Will be used as column headers for z value point dataframes.
continent_mask_filename : str
An optional parameter pointing to the full path to a continental mask for each timestep.
Assuming the time is in the filename, i.e. "/path/to/continent_mask_0Ma.nc", it should be
passed as "/path/to/continent_mask_{}Ma.nc" with curly brackets. Include decimal formatting
if needed.
"""

def __init__(
Expand All @@ -400,6 +406,7 @@ def __init__(
initial_ocean_mean_spreading_rate = 75.,
resume_from_checkpoints = False,
zval_names = ("SPREADING_RATE",),
continent_mask_filename = None,
):

# Provides a rotation model, topology features and reconstruction time for
Expand Down Expand Up @@ -510,6 +517,41 @@ def __init__(
self.default_column_headers = ['CURRENT_LONGITUDES', 'CURRENT_LATITUDES', 'SEAFLOOR_AGE', 'BIRTH_LAT_SNAPSHOT', 'POINT_ID_SNAPSHOT']
self.total_column_headers = np.concatenate([self.default_column_headers, self.zval_names])

# Filename for continental masks that the user can provide instead of building it here
self.continent_mask_filename = continent_mask_filename

# If the user provides a continental mask filename, we need to downsize the mask
# resolution for when we create the initial ocean mesh. The mesh does not need to be high-res.
if self.continent_mask_filename is not None:

# Determine which percentage to use to scale the continent mask resolution at max time
def _map_res_to_node_percentage(
self,
continent_mask_filename
):
maskY, maskX = grids.read_netcdf_grid(
continent_mask_filename.format(
self._max_time
)
).shape

mask_deg = _pixels2deg(maskX, self.extent[0], self.extent[1])

if mask_deg <= 0.1:
percentage = 0.1
elif mask_deg <= 0.25:
percentage = 0.3
elif mask_deg <= 0.5:
percentage = 0.5
elif mask_deg < 0.75:
percentage = 0.6
elif mask_deg >= 1:
percentage = 0.75
return mask_deg, percentage

_, self.percentage = _map_res_to_node_percentage(self, self.continent_mask_filename)


# Allow SeafloorGrid time to be updated, and to update the internally-used
# PlotTopologies' time attribute too. If PlotTopologies is used outside the
# object, its `time` attribute is not updated.
Expand Down Expand Up @@ -617,43 +659,78 @@ def create_initial_ocean_seed_points(self):
A feature collection of point objects on the ocean basin.
"""

# Ensure COB terranes at max time have reconstruction IDs and valid times
COB_polygons = ensure_polygon_geometry(
self._PlotTopologies_object.continents,
self.rotation_model,
self._max_time)

# zval is a binary array encoding whether a point
# coordinate is within a COB terrane polygon or not.
# Use the icosahedral mesh MultiPointOnSphere attribute
_, ocean_basin_point_mesh, zvals = point_in_polygon_routine(
self.icosahedral_multi_point,
COB_polygons
)

# Plates to partition with
plate_partitioner = pygplates.PlatePartitioner(
COB_polygons,
self.rotation_model,
)
if self.continent_mask_filename is None:
# Ensure COB terranes at max time have reconstruction IDs and valid times
COB_polygons = ensure_polygon_geometry(
self._PlotTopologies_object.continents,
self.rotation_model,
self._max_time)

# zval is a binary array encoding whether a point
# coordinate is within a COB terrane polygon or not.
# Use the icosahedral mesh MultiPointOnSphere attribute
_, ocean_basin_point_mesh, zvals = point_in_polygon_routine(
self.icosahedral_multi_point,
COB_polygons
)

# Plates to partition with
plate_partitioner = pygplates.PlatePartitioner(
COB_polygons,
self.rotation_model,
)

# Plate partition the ocean basin points
meshnode_feature = pygplates.Feature(
pygplates.FeatureType.create_from_qualified_string('gpml:MeshNode')
)
meshnode_feature.set_geometry(
ocean_basin_point_mesh
#multi_point
)
ocean_basin_meshnode = pygplates.FeatureCollection(meshnode_feature)
# Plate partition the ocean basin points
meshnode_feature = pygplates.Feature(
pygplates.FeatureType.create_from_qualified_string('gpml:MeshNode')
)
meshnode_feature.set_geometry(
ocean_basin_point_mesh
#multi_point
)
ocean_basin_meshnode = pygplates.FeatureCollection(meshnode_feature)

paleogeography = plate_partitioner.partition_features(
ocean_basin_meshnode, partition_return = pygplates.PartitionReturn.separate_partitioned_and_unpartitioned,
properties_to_copy=[pygplates.PropertyName.gpml_shapefile_attributes]
)
ocean_points = paleogeography[1] # Separate those inside polygons
continent_points = paleogeography[0] # Separate those outside polygons
paleogeography = plate_partitioner.partition_features(
ocean_basin_meshnode, partition_return = pygplates.PartitionReturn.separate_partitioned_and_unpartitioned,
properties_to_copy=[pygplates.PropertyName.gpml_shapefile_attributes]
)
ocean_points = paleogeography[1] # Separate those inside polygons
continent_points = paleogeography[0] # Separate those outside polygons

# If a set of continent masks was passed, we can use max_time's continental
# mask to build the initial profile of seafloor age.
else:
max_time_cont_mask = grids.Raster(
self.continent_mask_filename.format(self._max_time)
)
# If the input grid is at 0.5 degree uniform spacing, then the input
# grid is 7x more populated than a 6-level stripy icosahedral mesh and
# using this resolution for the initial ocean mesh will dramatically slow down
# reconstruction by topologies.
# Scale down the resolution based on the input mask resolution
# (percentage was found in __init__.)
max_time_cont_mask.resize(
int(max_time_cont_mask.shape[0]*self.percentage),
int(max_time_cont_mask.shape[1]*self.percentage),
inplace=True
)

lat = np.linspace(-90,90, max_time_cont_mask.shape[0])
lon = np.linspace(-180,180,max_time_cont_mask.shape[1])

llon, llat = np.meshgrid(lon, lat)

mask_inds = np.where(max_time_cont_mask.data.flatten() == 0)
mask_vals = max_time_cont_mask.data.flatten()
mask_lon = llon.flatten()[mask_inds]
mask_lat = llat.flatten()[mask_inds]

ocean_pt_feature = pygplates.Feature()
ocean_pt_feature.set_geometry(pygplates.MultiPointOnSphere(zip(mask_lat,mask_lon)))
ocean_points = [ocean_pt_feature]


# Now that we have ocean points...
# Determine age of ocean basin points using their proximity to MOR features
# and an assumed globally-uniform ocean basin mean spreading rate.
# We need resolved topologies at the `max_time` to pass into the proximity
Expand Down Expand Up @@ -1074,7 +1151,11 @@ def prepare_for_reconstruction_by_topologies(self):
# - seeding was completed but the subsequent gridding input creation was interrupted,
# seeding is assumed completed and skipped. The workflow automatically proceeds to re-gridding.

self.build_all_continental_masks()
if self.continent_mask_filename is None:
self.build_all_continental_masks()
else:
print("Continent masks passed to SeafloorGrid - skipping continental mask generation!")

self.build_all_MOR_seedpoints()

# ALL-TIME POINTS -----------------------------------------------------
Expand Down Expand Up @@ -1174,18 +1255,31 @@ def reconstruct_by_topologies(self):
]
)
# In addition to the default subduction detection, also detect continental collisions
if self.file_collection is not None:
# Use the input continent mask if it is provided.
if self.continent_mask_filename is not None:
collision_spec = reconstruction._ContinentCollision(
self.save_directory+"/"+self.file_collection+"_continent_mask_{}Ma.nc",

# This filename string should not have a time formatted into it - this is
# taken care of later.
self.continent_mask_filename,
default_collision,
verbose=False,
)
else:
collision_spec = reconstruction._ContinentCollision(
self.save_directory+"/continent_mask_{}Ma.nc",
default_collision,
verbose=False,
)
# If a continent mask is not provided, use the ones made.
if self.file_collection is not None:

collision_spec = reconstruction._ContinentCollision(
self.save_directory+"/"+self.file_collection+"_continent_mask_{}Ma.nc",
default_collision,
verbose=False,
)
else:
collision_spec = reconstruction._ContinentCollision(
self.save_directory+"/continent_mask_{}Ma.nc",
default_collision,
verbose=False,
)

# Call the reconstruct by topologies object
topology_reconstruction = reconstruction._ReconstructByTopologies(
Expand Down Expand Up @@ -1361,6 +1455,7 @@ def lat_lon_z_to_netCDF(
resX=self.spacingX,
resY=self.spacingY,
unmasked=unmasked,
continent_mask_filename=self.continent_mask_filename
)
else:
from joblib import delayed
Expand All @@ -1376,6 +1471,7 @@ def lat_lon_z_to_netCDF(
resX=self.spacingX,
resY=self.spacingY,
unmasked=unmasked,
continent_mask_filename=self.continent_mask_filename
)
for time in time_arr
)
Expand All @@ -1391,6 +1487,7 @@ def _lat_lon_z_to_netCDF_time(
resX,
resY,
unmasked=False,
continent_mask_filename=None
):
# Read the gridding input made by ReconstructByTopologies:
if file_collection is not None:
Expand Down Expand Up @@ -1429,41 +1526,46 @@ def _lat_lon_z_to_netCDF_time(
Z = griddata((lons, lats), zdata, (X, Y), method='nearest')

# Access continental grids from the save directory
if save_directory is not None:
if file_collection is not None:
if file_collection is not None:
if continent_mask_filename is not None:
full_directory=continent_mask_filename.format(time)
else:
full_directory = "{}/{}_continent_mask_{}Ma.nc".format(
save_directory,
file_collection,
time
)
grid_output_unmasked = "{}/{}_{}_grid_unmasked_{}Ma.nc".format(
save_directory,
file_collection,
str(zval_name),
time
)
grid_output_dir = "{}/{}_{}_grid_{}Ma.nc".format(
save_directory,
file_collection,
str(zval_name),
time
)
grid_output_unmasked = "{}/{}_{}_grid_unmasked_{}Ma.nc".format(
save_directory,
file_collection,
str(zval_name),
time
)
grid_output_dir = "{}/{}_{}_grid_{}Ma.nc".format(
save_directory,
file_collection,
str(zval_name),
time
)
else:
if continent_mask_filename is not None:
full_directory=continent_mask_filename.format(time)
else:
full_directory = "{}/{}_continent_mask_{}Ma.nc".format(
save_directory,
zval_name,
time
)
grid_output_unmasked = "{}/{}_grid_unmasked_{}Ma.nc".format(
save_directory,
zval_name,
time
)
grid_output_dir = "{}/{}_grid_{}Ma.nc".format(
save_directory,
zval_name,
time
)
grid_output_unmasked = "{}/{}_grid_unmasked_{}Ma.nc".format(
save_directory,
zval_name,
time
)
grid_output_dir = "{}/{}_grid_{}Ma.nc".format(
save_directory,
zval_name,
time
)

if unmasked:
grids.write_netcdf_grid(
Expand All @@ -1475,6 +1577,12 @@ def _lat_lon_z_to_netCDF_time(
# Identify regions in the grid in the continental mask
cont_mask = grids.Raster(data=str(full_directory))

# We need the continental mask to match the number of nodes
# in the uniform grid defined above. This is important if we
# pass our own continental mask to SeafloorGrid
if cont_mask.shape[1] != resX:
cont_mask.resize(resX, resY, inplace=True)

# Use the continental mask
Z = np.ma.array(
grids.Raster(data=Z).data.data,
Expand Down
Loading