Skip to content

Commit

Permalink
Merge pull request #1566 from OceanParcels/fieldset_from_directory
Browse files Browse the repository at this point in the history
Creating a new `Fieldset.from_modulefile()` method
  • Loading branch information
erikvansebille authored May 28, 2024
2 parents c2bf42c + 846520b commit 3f4168c
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 0 deletions.
28 changes: 28 additions & 0 deletions parcels/fieldset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import importlib.util
import sys
from copy import deepcopy
from glob import glob
from os import path
Expand Down Expand Up @@ -1062,6 +1064,32 @@ def from_xarray_dataset(cls, ds, variables, dimensions, mesh='spherical', allow_
v = fields.pop('V', None)
return cls(u, v, fields=fields)

@classmethod
def from_modulefile(cls, filename, modulename="create_fieldset", **kwargs):
"""Initialises FieldSet data from a file containing a python module file with a create_fieldset() function.
Parameters
----------
filename: path to a python file containing at least a function which returns a FieldSet object.
modulename: name of the function in the python file that returns a FieldSet object. Default is "create_fieldset".
"""
# check if filename exists
if not path.exists(filename):
raise IOError(f"FieldSet module file {filename} does not exist")

# Importing the source file directly (following https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly)
spec = importlib.util.spec_from_file_location(modulename, filename)
fieldset_module = importlib.util.module_from_spec(spec)
sys.modules[modulename] = fieldset_module
spec.loader.exec_module(fieldset_module)

if not hasattr(fieldset_module, modulename):
raise IOError(f"{filename} does not contain a {modulename} function")
fieldset = getattr(fieldset_module, modulename)(**kwargs)
if not isinstance(fieldset, FieldSet):
raise IOError(f"Module {filename}.{modulename} does not return a FieldSet object")
return fieldset

def get_fields(self):
"""Returns a list of all the :class:`parcels.field.Field` and :class:`parcels.field.VectorField`
objects associated with this FieldSet.
Expand Down
18 changes: 18 additions & 0 deletions tests/test_data/fieldset_nemo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from os import path

import parcels


def create_fieldset(indices=None):
data_path = path.join(path.dirname(__file__))

filenames = {'U': {'lon': path.join(data_path, 'mask_nemo_cross_180lon.nc'),
'lat': path.join(data_path, 'mask_nemo_cross_180lon.nc'),
'data': path.join(data_path, 'Uu_eastward_nemo_cross_180lon.nc')},
'V': {'lon': path.join(data_path, 'mask_nemo_cross_180lon.nc'),
'lat': path.join(data_path, 'mask_nemo_cross_180lon.nc'),
'data': path.join(data_path, 'Vv_eastward_nemo_cross_180lon.nc')}}
variables = {'U': 'U', 'V': 'V'}
dimensions = {'lon': 'glamf', 'lat': 'gphif', 'time': 'time_counter'}
indices = indices or {}
return parcels.FieldSet.from_nemo(filenames, variables, dimensions, indices=indices)
21 changes: 21 additions & 0 deletions tests/test_data/fieldset_nemo_error.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from os import path

import parcels


def random_function_name():
data_path = path.join(path.dirname(__file__))

filenames = {'U': {'lon': path.join(data_path, 'mask_nemo_cross_180lon.nc'),
'lat': path.join(data_path, 'mask_nemo_cross_180lon.nc'),
'data': path.join(data_path, 'Uu_eastward_nemo_cross_180lon.nc')},
'V': {'lon': path.join(data_path, 'mask_nemo_cross_180lon.nc'),
'lat': path.join(data_path, 'mask_nemo_cross_180lon.nc'),
'data': path.join(data_path, 'Vv_eastward_nemo_cross_180lon.nc')}}
variables = {'U': 'U', 'V': 'V'}
dimensions = {'lon': 'glamf', 'lat': 'gphif', 'time': 'time_counter'}
return parcels.FieldSet.from_nemo(filenames, variables, dimensions)


def none_returning_function():
return None
18 changes: 18 additions & 0 deletions tests/test_fieldset.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,24 @@ def test_field_from_netcdf(with_timestamps):
Field.from_netcdf(filenames, variable, dimensions, interp_method='cgrid_velocity')


def test_fieldset_from_modulefile():
data_path = path.join(path.dirname(__file__), 'test_data/')
fieldset = FieldSet.from_modulefile(data_path + 'fieldset_nemo.py')
assert fieldset.U.creation_log == 'from_nemo'

indices = {'lon': range(6, 10)}
fieldset = FieldSet.from_modulefile(data_path + 'fieldset_nemo.py', indices=indices)
assert fieldset.U.grid.lon.shape[1] == 4

with pytest.raises(IOError):
FieldSet.from_modulefile(data_path + 'fieldset_nemo_error.py')

FieldSet.from_modulefile(data_path + 'fieldset_nemo_error.py', modulename='random_function_name')

with pytest.raises(IOError):
FieldSet.from_modulefile(data_path + 'fieldset_nemo_error.py', modulename='none_returning_function')


def test_field_from_netcdf_fieldtypes():
data_path = path.join(path.dirname(__file__), 'test_data/')

Expand Down

0 comments on commit 3f4168c

Please sign in to comment.