From 62e8e2127301fd30148a8e887134f6987ebb42e7 Mon Sep 17 00:00:00 2001 From: Erik van Sebille Date: Mon, 29 Jan 2024 08:08:14 +0100 Subject: [PATCH 01/12] Creating ParticleClass.add_variable method And updating parcels_tutorial and test_particlesets --- docs/examples/parcels_tutorial.ipynb | 95 +++++++++++++--------------- parcels/particle.py | 10 +++ tests/test_particlesets.py | 22 +++++-- 3 files changed, 69 insertions(+), 58 deletions(-) diff --git a/docs/examples/parcels_tutorial.ipynb b/docs/examples/parcels_tutorial.ipynb index 98c823099..d9dea686b 100644 --- a/docs/examples/parcels_tutorial.ipynb +++ b/docs/examples/parcels_tutorial.ipynb @@ -54,7 +54,6 @@ " FieldSet,\n", " JITParticle,\n", " ParticleSet,\n", - " Variable,\n", " download_example_dataset,\n", ")" ] @@ -228,7 +227,7 @@ "output_type": "stream", "text": [ "INFO: Output files are stored in EddyParticles.zarr.\n", - "100%|██████████| 518400.0/518400.0 [00:02<00:00, 187445.76it/s]\n" + "100%|██████████| 518400.0/518400.0 [00:02<00:00, 176997.87it/s]\n" ] } ], @@ -569,42 +568,42 @@ "\n", "\n", "
\n", - " \n", + " \n", "
\n", - " \n", + " oninput=\"anim2a02ae06648044e08b637ea620eb6c63.set_frame(parseInt(this.value));\">\n", "
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", "
\n", - "
\n", - " \n", - " \n", - " Once\n", + " \n", - " \n", - " Loop\n", + " \n", - " \n", + " \n", "
\n", "
\n", "
\n", @@ -614,9 +613,9 @@ " /* Instantiate the Animation class. */\n", " /* The IDs given should match those used in the template above. */\n", " (function() {\n", - " var img_id = \"_anim_img8bfbea7abfae4cbf885a957f3f62a93b\";\n", - " var slider_id = \"_anim_slider8bfbea7abfae4cbf885a957f3f62a93b\";\n", - " var loop_select_id = \"_anim_loop_select8bfbea7abfae4cbf885a957f3f62a93b\";\n", + " var img_id = \"_anim_img2a02ae06648044e08b637ea620eb6c63\";\n", + " var slider_id = \"_anim_slider2a02ae06648044e08b637ea620eb6c63\";\n", + " var loop_select_id = \"_anim_loop_select2a02ae06648044e08b637ea620eb6c63\";\n", " var frames = new Array(29);\n", " \n", " frames[0] = \"\\\n", @@ -13156,7 +13155,7 @@ " /* set a timeout to make sure all the above elements are created before\n", " the object is initialized. */\n", " setTimeout(function() {\n", - " anim8bfbea7abfae4cbf885a957f3f62a93b = new Animation(frames, img_id, slider_id, 100.0,\n", + " anim2a02ae06648044e08b637ea620eb6c63 = new Animation(frames, img_id, slider_id, 100.0,\n", " loop_select_id);\n", " }, 0);\n", " })()\n", @@ -13201,7 +13200,7 @@ "output_type": "stream", "text": [ "INFO: Output files are stored in EddyParticles_Bwd.zarr.\n", - "100%|██████████| 518400.0/518400.0 [00:02<00:00, 188464.76it/s]\n" + "100%|██████████| 518400.0/518400.0 [00:02<00:00, 176188.58it/s]\n" ] } ], @@ -13323,7 +13322,7 @@ "output_type": "stream", "text": [ "INFO: Output files are stored in EddyParticles_WestVel.zarr.\n", - "100%|██████████| 172800.0/172800.0 [00:00<00:00, 180626.58it/s]\n" + "100%|██████████| 172800.0/172800.0 [00:01<00:00, 171451.01it/s]\n" ] } ], @@ -13503,7 +13502,7 @@ "output_type": "stream", "text": [ "INFO: Output files are stored in GlobCurrentParticles.zarr.\n", - "100%|██████████| 864000.0/864000.0 [00:00<00:00, 1104294.63it/s]\n" + "100%|██████████| 864000.0/864000.0 [00:00<00:00, 1057631.05it/s]\n" ] } ], @@ -13592,7 +13591,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Now define a new `Particle` class that has an extra `Variable`: the pressure. We initialise this by sampling the `fieldset.P` field.\n" + "Now add a new variable `p` to the Particle Class to store the pressure by sampling the `fieldset.P` field.\n" ] }, { @@ -13601,11 +13600,7 @@ "metadata": {}, "outputs": [], "source": [ - "class SampleParticle(JITParticle):\n", - " \"\"\"Define a new particle class with variable 'p'\n", - " initialised by sampling the pressure\"\"\"\n", - "\n", - " p = Variable(\"p\")" + "JITParticle.add_variable(\"p\") # add variable p to JITParticles to store the pressure field" ] }, { @@ -13635,7 +13630,7 @@ "source": [ "pset = ParticleSet.from_line(\n", " fieldset=fieldset,\n", - " pclass=SampleParticle,\n", + " pclass=JITParticle,\n", " start=(3000, 3000),\n", " finish=(3000, 46000),\n", " size=5,\n", @@ -13688,7 +13683,7 @@ "output_type": "stream", "text": [ "INFO: Output files are stored in PeninsulaPressure.zarr.\n", - "100%|██████████| 72000.0/72000.0 [00:00<00:00, 157840.98it/s]\n" + "100%|██████████| 72000.0/72000.0 [00:00<00:00, 148175.61it/s]\n" ] } ], @@ -13768,7 +13763,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "First, we need to create a new `Particle` class that includes three extra variables. The `distance` variable will be written to output, but the auxiliary variables `prev_lon` and `prev_lat` won't be written to output (can be controlled using the `to_write` keyword)\n" + "First, we need to add three extra variables to the Particle Class. The `distance` variable will be written to output, but the auxiliary variables `prev_lon` and `prev_lat` won't be written to output (can be controlled using the `to_write` keyword)." ] }, { @@ -13777,19 +13772,17 @@ "metadata": {}, "outputs": [], "source": [ - "class DistParticle(JITParticle):\n", - " \"\"\"Define a new particle class that contains three extra variables\"\"\"\n", + "JITParticle.add_variable(\"distance\", initial=0.0, dtype=np.float32)\n", "\n", - " # the distance travelled by the particle\n", - " distance = Variable(\"distance\", initial=0.0, dtype=np.float32)\n", + "JITParticle.add_variable(\"prev_lon\",\n", + " dtype=np.float32,\n", + " to_write=False,\n", + " initial=attrgetter(\"lon\"))\n", "\n", - " # the previous longitude and latitude of the particle\n", - " prev_lon = Variable(\n", - " \"prev_lon\", dtype=np.float32, to_write=False, initial=attrgetter(\"lon\")\n", - " )\n", - " prev_lat = Variable(\n", - " \"prev_lat\", dtype=np.float32, to_write=False, initial=attrgetter(\"lat\")\n", - " )" + "JITParticle.add_variable(\"prev_lat\",\n", + " dtype=np.float32,\n", + " to_write=False,\n", + " initial=attrgetter(\"lat\"))" ] }, { @@ -13857,7 +13850,7 @@ "dimensions = {\"lat\": \"lat\", \"lon\": \"lon\", \"time\": \"time\"}\n", "fieldset = FieldSet.from_netcdf(filenames, variables, dimensions)\n", "pset = ParticleSet.from_line(\n", - " fieldset=fieldset, pclass=DistParticle, size=5, start=(28, -33), finish=(30, -33)\n", + " fieldset=fieldset, pclass=JITParticle, size=5, start=(28, -33), finish=(30, -33)\n", ")" ] }, @@ -13879,7 +13872,7 @@ "output_type": "stream", "text": [ "INFO: Output files are stored in GlobCurrentParticles_Dist.zarr.\n", - "100%|██████████| 518400.0/518400.0 [00:03<00:00, 147724.13it/s]\n" + "100%|██████████| 518400.0/518400.0 [00:04<00:00, 118492.65it/s]\n" ] } ], @@ -13936,7 +13929,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.9" + "version": "3.11.6" } }, "nbformat": 4, diff --git a/parcels/particle.py b/parcels/particle.py index 0e4765f34..066458780 100644 --- a/parcels/particle.py +++ b/parcels/particle.py @@ -229,6 +229,16 @@ def __repr__(self): str += f"{var}={getattr(self, var):f}, " return str + f"time={time_string})" + @classmethod + def add_variable(cls, var, *args, **kwargs): + """Add a new variable to the Particle class""" + if not isinstance(var, Variable): + dtype = kwargs.pop('dtype', np.float32) + initial = kwargs.pop('initial', 0) + to_write = kwargs.pop('to_write', True) + var = Variable(var, dtype=dtype, initial=initial, to_write=to_write) + setattr(cls, var.name, var) + @classmethod def set_lonlatdepth_dtype(cls, dtype): cls.lon.dtype = dtype diff --git a/tests/test_particlesets.py b/tests/test_particlesets.py index 54074a52f..a74421f4c 100644 --- a/tests/test_particlesets.py +++ b/tests/test_particlesets.py @@ -58,7 +58,9 @@ def test_pset_create_list_with_customvariable(fieldset, mode, npart=100): lat = np.linspace(1, 0, npart, dtype=np.float32) class MyParticle(ptype[mode]): - v = Variable('v') + pass + + MyParticle.add_variable(Variable("v")) v_vals = np.arange(npart) pset = ParticleSet.from_list(fieldset, lon=lon, lat=lat, v=v_vals, pclass=MyParticle) @@ -75,9 +77,11 @@ def test_pset_create_fromparticlefile(fieldset, mode, restart, tmpdir): lat = np.linspace(1, 0, 10, dtype=np.float32) class TestParticle(ptype[mode]): - p = Variable('p', np.float32, initial=0.33) - p2 = Variable('p2', np.float32, initial=1, to_write=False) - p3 = Variable('p3', np.float32, to_write='once') + pass + + TestParticle.add_variable('p', np.float32, initial=0.33) + TestParticle.add_variable('p2', np.float32, initial=1, to_write=False) + TestParticle.add_variable('p3', np.float32, to_write='once') pset = ParticleSet(fieldset, lon=lon, lat=lat, depth=[4]*len(lon), pclass=TestParticle, p3=np.arange(len(lon))) pfile = pset.ParticleFile(filename, outputdt=1) @@ -237,8 +241,10 @@ def test_pset_access(fieldset, mode, npart=100): @pytest.mark.parametrize('mode', ['scipy', 'jit']) def test_pset_custom_ptype(fieldset, mode, npart=100): class TestParticle(ptype[mode]): - p = Variable('p', np.float32, initial=0.33) - n = Variable('n', np.int32, initial=2) + pass + + TestParticle.add_variable('p', np.float32, initial=0.33) + TestParticle.add_variable('n', np.int32, initial=2) pset = ParticleSet(fieldset, pclass=TestParticle, lon=np.linspace(0, 1, npart), @@ -422,7 +428,9 @@ def test_from_field_exact_val(staggered_grid): fieldset.add_field(FMask) class SampleParticle(ptype['scipy']): - mask = Variable('mask', initial=0) + pass + + SampleParticle.add_variable('mask', initial=0) def SampleMask(particle, fieldset, time): particle.mask = fieldset.mask[particle] From 3017d84d28e7db92f05d3ed0b74c62095c25934f Mon Sep 17 00:00:00 2001 From: Erik van Sebille Date: Mon, 29 Jan 2024 10:12:45 +0100 Subject: [PATCH 02/12] Updating add_variable to return new class This is a cleaner way than updating the class itself. Also adding add_variables method for a list of Variables --- docs/examples/parcels_tutorial.ipynb | 81 +++++++++++++--------------- parcels/particle.py | 34 +++++++++++- tests/test_particlesets.py | 25 +++------ 3 files changed, 77 insertions(+), 63 deletions(-) diff --git a/docs/examples/parcels_tutorial.ipynb b/docs/examples/parcels_tutorial.ipynb index d9dea686b..aea33d664 100644 --- a/docs/examples/parcels_tutorial.ipynb +++ b/docs/examples/parcels_tutorial.ipynb @@ -54,6 +54,7 @@ " FieldSet,\n", " JITParticle,\n", " ParticleSet,\n", + " Variable,\n", " download_example_dataset,\n", ")" ] @@ -227,7 +228,7 @@ "output_type": "stream", "text": [ "INFO: Output files are stored in EddyParticles.zarr.\n", - "100%|██████████| 518400.0/518400.0 [00:02<00:00, 176997.87it/s]\n" + "100%|██████████| 518400.0/518400.0 [00:03<00:00, 172443.78it/s]\n" ] } ], @@ -568,42 +569,42 @@ "\n", "\n", "
\n", - " \n", + " \n", "
\n", - " \n", + " oninput=\"anim39fc9ac9ecc54c368c0d0e047c7673e6.set_frame(parseInt(this.value));\">\n", "
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", "
\n", - "
\n", - " \n", - " \n", - " Once\n", + " \n", - " \n", - " Loop\n", + " \n", - " \n", + " \n", "
\n", "
\n", "
\n", @@ -613,9 +614,9 @@ " /* Instantiate the Animation class. */\n", " /* The IDs given should match those used in the template above. */\n", " (function() {\n", - " var img_id = \"_anim_img2a02ae06648044e08b637ea620eb6c63\";\n", - " var slider_id = \"_anim_slider2a02ae06648044e08b637ea620eb6c63\";\n", - " var loop_select_id = \"_anim_loop_select2a02ae06648044e08b637ea620eb6c63\";\n", + " var img_id = \"_anim_img39fc9ac9ecc54c368c0d0e047c7673e6\";\n", + " var slider_id = \"_anim_slider39fc9ac9ecc54c368c0d0e047c7673e6\";\n", + " var loop_select_id = \"_anim_loop_select39fc9ac9ecc54c368c0d0e047c7673e6\";\n", " var frames = new Array(29);\n", " \n", " frames[0] = \"\\\n", @@ -13155,7 +13156,7 @@ " /* set a timeout to make sure all the above elements are created before\n", " the object is initialized. */\n", " setTimeout(function() {\n", - " anim2a02ae06648044e08b637ea620eb6c63 = new Animation(frames, img_id, slider_id, 100.0,\n", + " anim39fc9ac9ecc54c368c0d0e047c7673e6 = new Animation(frames, img_id, slider_id, 100.0,\n", " loop_select_id);\n", " }, 0);\n", " })()\n", @@ -13200,7 +13201,7 @@ "output_type": "stream", "text": [ "INFO: Output files are stored in EddyParticles_Bwd.zarr.\n", - "100%|██████████| 518400.0/518400.0 [00:02<00:00, 176188.58it/s]\n" + "100%|██████████| 518400.0/518400.0 [00:02<00:00, 176426.86it/s]\n" ] } ], @@ -13322,7 +13323,7 @@ "output_type": "stream", "text": [ "INFO: Output files are stored in EddyParticles_WestVel.zarr.\n", - "100%|██████████| 172800.0/172800.0 [00:01<00:00, 171451.01it/s]\n" + "100%|██████████| 172800.0/172800.0 [00:00<00:00, 179532.85it/s]\n" ] } ], @@ -13502,7 +13503,7 @@ "output_type": "stream", "text": [ "INFO: Output files are stored in GlobCurrentParticles.zarr.\n", - "100%|██████████| 864000.0/864000.0 [00:00<00:00, 1057631.05it/s]\n" + "100%|██████████| 864000.0/864000.0 [00:00<00:00, 1072517.72it/s]\n" ] } ], @@ -13591,7 +13592,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Now add a new variable `p` to the Particle Class to store the pressure by sampling the `fieldset.P` field.\n" + "Now define a new `Particle` class that has an extra `Variable`: the pressure. This `particle.p` can be used to store the values of the `fieldset.P` field at the particle locations.\n" ] }, { @@ -13600,7 +13601,7 @@ "metadata": {}, "outputs": [], "source": [ - "JITParticle.add_variable(\"p\") # add variable p to JITParticles to store the pressure field" + "SampleParticle = JITParticle.add_variable(\"p\") # add variable p to Particle class to store the pressure field" ] }, { @@ -13630,7 +13631,7 @@ "source": [ "pset = ParticleSet.from_line(\n", " fieldset=fieldset,\n", - " pclass=JITParticle,\n", + " pclass=SampleParticle,\n", " start=(3000, 3000),\n", " finish=(3000, 46000),\n", " size=5,\n", @@ -13683,7 +13684,7 @@ "output_type": "stream", "text": [ "INFO: Output files are stored in PeninsulaPressure.zarr.\n", - "100%|██████████| 72000.0/72000.0 [00:00<00:00, 148175.61it/s]\n" + "100%|██████████| 72000.0/72000.0 [00:00<00:00, 143717.52it/s]\n" ] } ], @@ -13772,17 +13773,11 @@ "metadata": {}, "outputs": [], "source": [ - "JITParticle.add_variable(\"distance\", initial=0.0, dtype=np.float32)\n", - "\n", - "JITParticle.add_variable(\"prev_lon\",\n", - " dtype=np.float32,\n", - " to_write=False,\n", - " initial=attrgetter(\"lon\"))\n", + "extra_vars = [Variable(\"distance\", initial=0.0, dtype=np.float32),\n", + " Variable(\"prev_lon\", dtype=np.float32, to_write=False, initial=attrgetter(\"lon\")),\n", + " Variable(\"prev_lat\", dtype=np.float32, to_write=False, initial=attrgetter(\"lat\"))]\n", "\n", - "JITParticle.add_variable(\"prev_lat\",\n", - " dtype=np.float32,\n", - " to_write=False,\n", - " initial=attrgetter(\"lat\"))" + "DistParticle = JITParticle.add_variables(extra_vars)" ] }, { @@ -13850,7 +13845,7 @@ "dimensions = {\"lat\": \"lat\", \"lon\": \"lon\", \"time\": \"time\"}\n", "fieldset = FieldSet.from_netcdf(filenames, variables, dimensions)\n", "pset = ParticleSet.from_line(\n", - " fieldset=fieldset, pclass=JITParticle, size=5, start=(28, -33), finish=(30, -33)\n", + " fieldset=fieldset, pclass=DistParticle, size=5, start=(28, -33), finish=(30, -33)\n", ")" ] }, @@ -13872,7 +13867,7 @@ "output_type": "stream", "text": [ "INFO: Output files are stored in GlobCurrentParticles_Dist.zarr.\n", - "100%|██████████| 518400.0/518400.0 [00:04<00:00, 118492.65it/s]\n" + "100%|██████████| 518400.0/518400.0 [00:03<00:00, 136275.28it/s]\n" ] } ], diff --git a/parcels/particle.py b/parcels/particle.py index 066458780..0cc3a6c32 100644 --- a/parcels/particle.py +++ b/parcels/particle.py @@ -231,13 +231,43 @@ def __repr__(self): @classmethod def add_variable(cls, var, *args, **kwargs): - """Add a new variable to the Particle class""" + """Add a new variable to the Particle class + + Parameters + ---------- + var : str, Variable or list of Variables + Variable object to be added. Can be the name of the Variable, + a Variable object, or a list of Variable objects + """ + + if isinstance(var, list): + return cls.add_variables(var) if not isinstance(var, Variable): dtype = kwargs.pop('dtype', np.float32) initial = kwargs.pop('initial', 0) to_write = kwargs.pop('to_write', True) var = Variable(var, dtype=dtype, initial=initial, to_write=to_write) - setattr(cls, var.name, var) + + class NewParticle(cls): + pass + + setattr(NewParticle, var.name, var) + return NewParticle + + @classmethod + def add_variables(cls, variables): + """Add multiple new variables to the Particle class + + Parameters + ---------- + variables : list of Variable + Variable objects to be added. Has to be a list of Variable objects + """ + + NewParticle = cls + for var in variables: + NewParticle = NewParticle.add_variable(var) + return NewParticle @classmethod def set_lonlatdepth_dtype(cls, dtype): diff --git a/tests/test_particlesets.py b/tests/test_particlesets.py index a74421f4c..e7204856b 100644 --- a/tests/test_particlesets.py +++ b/tests/test_particlesets.py @@ -57,10 +57,7 @@ def test_pset_create_list_with_customvariable(fieldset, mode, npart=100): lon = np.linspace(0, 1, npart, dtype=np.float32) lat = np.linspace(1, 0, npart, dtype=np.float32) - class MyParticle(ptype[mode]): - pass - - MyParticle.add_variable(Variable("v")) + MyParticle = ptype[mode].add_variable(Variable("v")) v_vals = np.arange(npart) pset = ParticleSet.from_list(fieldset, lon=lon, lat=lat, v=v_vals, pclass=MyParticle) @@ -76,12 +73,9 @@ def test_pset_create_fromparticlefile(fieldset, mode, restart, tmpdir): lon = np.linspace(0, 1, 10, dtype=np.float32) lat = np.linspace(1, 0, 10, dtype=np.float32) - class TestParticle(ptype[mode]): - pass - - TestParticle.add_variable('p', np.float32, initial=0.33) - TestParticle.add_variable('p2', np.float32, initial=1, to_write=False) - TestParticle.add_variable('p3', np.float32, to_write='once') + TestParticle = ptype[mode].add_variable('p', np.float32, initial=0.33) + TestParticle = TestParticle.add_variable('p2', np.float32, initial=1, to_write=False) + TestParticle = TestParticle.add_variable('p3', np.float32, to_write='once') pset = ParticleSet(fieldset, lon=lon, lat=lat, depth=[4]*len(lon), pclass=TestParticle, p3=np.arange(len(lon))) pfile = pset.ParticleFile(filename, outputdt=1) @@ -240,11 +234,9 @@ def test_pset_access(fieldset, mode, npart=100): @pytest.mark.parametrize('mode', ['scipy', 'jit']) def test_pset_custom_ptype(fieldset, mode, npart=100): - class TestParticle(ptype[mode]): - pass - TestParticle.add_variable('p', np.float32, initial=0.33) - TestParticle.add_variable('n', np.int32, initial=2) + TestParticle = ptype[mode].add_variable([Variable('p', np.float32, initial=0.33), + Variable('n', np.int32, initial=2)]) pset = ParticleSet(fieldset, pclass=TestParticle, lon=np.linspace(0, 1, npart), @@ -427,10 +419,7 @@ def test_from_field_exact_val(staggered_grid): FMask = Field('mask', mask, lon, lat, interp_method='cgrid_tracer') fieldset.add_field(FMask) - class SampleParticle(ptype['scipy']): - pass - - SampleParticle.add_variable('mask', initial=0) + SampleParticle = ptype['scipy'].add_variable('mask', initial=0) def SampleMask(particle, fieldset, time): particle.mask = fieldset.mask[particle] From 22538168e7cda6f0b1dc8f13848c4ebf4c28d183 Mon Sep 17 00:00:00 2001 From: Erik van Sebille Date: Mon, 29 Jan 2024 11:06:10 +0100 Subject: [PATCH 03/12] Also parsing *args in Particle.add_variable --- parcels/particle.py | 6 ++++++ tests/test_particlesets.py | 3 ++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/parcels/particle.py b/parcels/particle.py index 0cc3a6c32..57ddc15be 100644 --- a/parcels/particle.py +++ b/parcels/particle.py @@ -243,6 +243,12 @@ def add_variable(cls, var, *args, **kwargs): if isinstance(var, list): return cls.add_variables(var) if not isinstance(var, Variable): + if len(args) > 0: + kwargs['dtype'] = args[0] + if len(args) > 1: + kwargs['initial'] = args[1] + if len(args) > 2: + kwargs['to_write'] = args[2] dtype = kwargs.pop('dtype', np.float32) initial = kwargs.pop('initial', 0) to_write = kwargs.pop('to_write', True) diff --git a/tests/test_particlesets.py b/tests/test_particlesets.py index e7204856b..fcd7d34e3 100644 --- a/tests/test_particlesets.py +++ b/tests/test_particlesets.py @@ -75,7 +75,7 @@ def test_pset_create_fromparticlefile(fieldset, mode, restart, tmpdir): TestParticle = ptype[mode].add_variable('p', np.float32, initial=0.33) TestParticle = TestParticle.add_variable('p2', np.float32, initial=1, to_write=False) - TestParticle = TestParticle.add_variable('p3', np.float32, to_write='once') + TestParticle = TestParticle.add_variable('p3', np.float64, to_write='once') pset = ParticleSet(fieldset, lon=lon, lat=lat, depth=[4]*len(lon), pclass=TestParticle, p3=np.arange(len(lon))) pfile = pset.ParticleFile(filename, outputdt=1) @@ -97,6 +97,7 @@ def Kernel(particle, fieldset, time): assert np.allclose([p.id for p in pset], [p.id for p in pset_new]) pset_new.execute(Kernel, runtime=2, dt=1) assert len(pset_new) == 3*len(pset) + assert pset[0].p3.dtype == np.float64 @pytest.mark.parametrize('mode', ['scipy']) From 75c41a6988dbed355fec0d2d5868cf1e25db2734 Mon Sep 17 00:00:00 2001 From: Erik van Sebille Date: Mon, 29 Jan 2024 11:07:49 +0100 Subject: [PATCH 04/12] Updating some examples and tests to use add_variables --- docs/examples/documentation_stuck_particles.ipynb | 5 +---- docs/examples/example_globcurrent.py | 7 ++----- docs/examples/example_peninsula.py | 7 ++----- docs/examples/example_stommel.py | 10 +++++----- tests/test_particlefile.py | 12 +++++------- tests/test_particles.py | 12 +++++------- 6 files changed, 20 insertions(+), 33 deletions(-) diff --git a/docs/examples/documentation_stuck_particles.ipynb b/docs/examples/documentation_stuck_particles.ipynb index e21cad50d..56cd2bc44 100644 --- a/docs/examples/documentation_stuck_particles.ipynb +++ b/docs/examples/documentation_stuck_particles.ipynb @@ -82,7 +82,6 @@ " FieldSet,\n", " JITParticle,\n", " ParticleSet,\n", - " Variable,\n", " download_example_dataset,\n", ")" ] @@ -1222,9 +1221,7 @@ }, "outputs": [], "source": [ - "class LandParticle(JITParticle):\n", - " on_land = Variable(\"on_land\")\n", - "\n", + "LandParticle = JITParticle.add_variable(\"on_land\")\n", "\n", "def Sample_land(particle, fieldset, time):\n", " particle.on_land = fieldset.landmask[\n", diff --git a/docs/examples/example_globcurrent.py b/docs/examples/example_globcurrent.py index 4f53d6af0..a705bf193 100755 --- a/docs/examples/example_globcurrent.py +++ b/docs/examples/example_globcurrent.py @@ -13,7 +13,6 @@ ParticleSet, ScipyParticle, TimeExtrapolationError, - Variable, download_example_dataset, ) @@ -99,8 +98,7 @@ def test_globcurrent_time_periodic(mode, rundays): for deferred_load in [True, False]: fieldset = set_globcurrent_fieldset(time_periodic=delta(days=365), deferred_load=deferred_load) - class MyParticle(ptype[mode]): - sample_var = Variable('sample_var', initial=0.) + MyParticle = ptype[mode].add_variable('sample_var', initial=0.) pset = ParticleSet(fieldset, pclass=MyParticle, lon=25, lat=-35, time=fieldset.U.grid.time[0]) @@ -194,8 +192,7 @@ def test_globcurrent_startparticles_between_time_arrays(mode, dt, with_starttime fieldset.add_field(Field.from_netcdf(fnamesFeb, ('P', 'eastward_eulerian_current_velocity'), {'lat': 'lat', 'lon': 'lon', 'time': 'time'})) - class MyParticle(ptype[mode]): - sample_var = Variable('sample_var', initial=0.) + MyParticle = ptype[mode].add_variable('sample_var', initial=0.) def SampleP(particle, fieldset, time): particle.sample_var += fieldset.P[time, particle.depth, particle.lat, particle.lon] diff --git a/docs/examples/example_peninsula.py b/docs/examples/example_peninsula.py index 1d0184584..a96479d02 100644 --- a/docs/examples/example_peninsula.py +++ b/docs/examples/example_peninsula.py @@ -130,11 +130,8 @@ def peninsula_example(fieldset, outfile, npart, mode='jit', degree=1, # First, we define a custom Particle class to which we add a # custom variable, the initial stream function value p. # We determine the particle base class according to mode. - class MyParticle(ptype[mode]): - # JIT compilation requires a-priori knowledge of the particle - # data structure, so we define additional variables here. - p = Variable('p', dtype=np.float32, initial=0.) - p_start = Variable('p_start', dtype=np.float32, initial=0) + MyParticle = ptype[mode].add_variable([Variable('p', dtype=np.float32, initial=0.), + Variable('p_start', dtype=np.float32, initial=0)]) # Initialise particles if fieldset.U.grid.mesh == 'flat': diff --git a/docs/examples/example_stommel.py b/docs/examples/example_stommel.py index 47f9b69a7..052987638 100755 --- a/docs/examples/example_stommel.py +++ b/docs/examples/example_stommel.py @@ -105,11 +105,11 @@ def stommel_example(npart=1, mode='jit', verbose=False, method=AdvectionRK4, gri dt = delta(hours=1) outputdt = delta(days=5) - class MyParticle(ParticleClass): - p = Variable('p', dtype=np.float32, initial=0.) - p_start = Variable('p_start', dtype=np.float32, initial=0.) - next_dt = Variable('next_dt', dtype=np.float64, initial=dt.total_seconds()) - age = Variable('age', dtype=np.float32, initial=0.) + extra_vars = [Variable('p', dtype=np.float32, initial=0.), + Variable('p_start', dtype=np.float32, initial=0.), + Variable('next_dt', dtype=np.float64, initial=dt.total_seconds()), + Variable('age', dtype=np.float32, initial=0.)] + MyParticle = ParticleClass.add_variables(extra_vars) if custom_partition_function: pset = ParticleSet.from_line(fieldset, size=npart, pclass=MyParticle, repeatdt=repeatdt, diff --git a/tests/test_particlefile.py b/tests/test_particlefile.py index 8977acb0c..11356dfd3 100644 --- a/tests/test_particlefile.py +++ b/tests/test_particlefile.py @@ -158,14 +158,12 @@ def Update_lon(particle, fieldset, time): def test_write_dtypes_pfile(fieldset, mode, tmpdir): filepath = tmpdir.join("pfile_dtypes.zarr") - dtypes = ['float32', 'float64', 'int32', 'uint32', 'int64', 'uint64'] + dtypes = [np.float32, np.float64, np.int32, np.uint32, np.int64, np.uint64] if mode == 'scipy': - dtypes.extend(['bool_', 'int8', 'uint8', 'int16', 'uint16']) + dtypes.extend([np.bool_, np.int8, np.uint8, np.int16, np.uint16]) - class MyParticle(ptype[mode]): - for d in dtypes: - # need an exec() here because we need to dynamically set the variable name - exec(f'v_{d} = Variable("v_{d}", dtype=np.{d}, initial=0.)') + extra_vars = [Variable(f'v_{d.__name__}', dtype=d, initial=0.) for d in dtypes] + MyParticle = ptype[mode].add_variables(extra_vars) pset = ParticleSet(fieldset, pclass=MyParticle, lon=0, lat=0, time=0) pfile = pset.ParticleFile(name=filepath, outputdt=1) @@ -173,7 +171,7 @@ class MyParticle(ptype[mode]): ds = xr.open_zarr(filepath, mask_and_scale=False) # Note masking issue at https://stackoverflow.com/questions/68460507/xarray-loading-int-data-as-float for d in dtypes: - assert ds[f'v_{d}'].dtype == d + assert ds[f'v_{d.__name__}'].dtype == d @pytest.mark.parametrize('mode', ['scipy', 'jit']) diff --git a/tests/test_particles.py b/tests/test_particles.py index 615e13293..fa8daffb6 100644 --- a/tests/test_particles.py +++ b/tests/test_particles.py @@ -30,8 +30,7 @@ def fieldset_fixture(xdim=100, ydim=100): @pytest.mark.parametrize('mode', ['scipy', 'jit']) def test_print(fieldset, mode): - class TestParticle(ptype[mode]): - p = Variable('p', to_write=True) + TestParticle = ptype[mode].add_variable('p', to_write=True) pset = ParticleSet(fieldset, pclass=TestParticle, lon=[0, 1], lat=[0, 1]) print(pset) @@ -39,11 +38,10 @@ class TestParticle(ptype[mode]): @pytest.mark.parametrize('mode', ['scipy', 'jit']) def test_variable_init(fieldset, mode, npart=10): """Test that checks correct initialisation of custom variables.""" - class TestParticle(ptype[mode]): - p_float = Variable('p_float', dtype=np.float32, initial=10.) - p_double = Variable('p_double', dtype=np.float64, initial=11.) - p_int = Variable('p_int', dtype=np.int32, initial=12.) - + extra_vars = [Variable('p_float', dtype=np.float32, initial=10.), + Variable('p_double', dtype=np.float64, initial=11.)] + TestParticle = ptype[mode].add_variables(extra_vars) + TestParticle = TestParticle.add_variable('p_int', np.int32, initial=12.) pset = ParticleSet(fieldset, pclass=TestParticle, lon=np.linspace(0, 1, npart), lat=np.linspace(1, 0, npart)) From 301d3ff4981e409be06ef388b2e152baf567d483 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 29 Jan 2024 10:10:30 +0000 Subject: [PATCH 05/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/examples/documentation_stuck_particles.ipynb | 1 + docs/examples/parcels_tutorial.ipynb | 12 ++++++++---- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/docs/examples/documentation_stuck_particles.ipynb b/docs/examples/documentation_stuck_particles.ipynb index 56cd2bc44..681748e7e 100644 --- a/docs/examples/documentation_stuck_particles.ipynb +++ b/docs/examples/documentation_stuck_particles.ipynb @@ -1223,6 +1223,7 @@ "source": [ "LandParticle = JITParticle.add_variable(\"on_land\")\n", "\n", + "\n", "def Sample_land(particle, fieldset, time):\n", " particle.on_land = fieldset.landmask[\n", " time, particle.depth, particle.lat, particle.lon\n", diff --git a/docs/examples/parcels_tutorial.ipynb b/docs/examples/parcels_tutorial.ipynb index aea33d664..d23e823f1 100644 --- a/docs/examples/parcels_tutorial.ipynb +++ b/docs/examples/parcels_tutorial.ipynb @@ -13601,7 +13601,9 @@ "metadata": {}, "outputs": [], "source": [ - "SampleParticle = JITParticle.add_variable(\"p\") # add variable p to Particle class to store the pressure field" + "SampleParticle = JITParticle.add_variable(\n", + " \"p\"\n", + ") # add variable p to Particle class to store the pressure field" ] }, { @@ -13773,9 +13775,11 @@ "metadata": {}, "outputs": [], "source": [ - "extra_vars = [Variable(\"distance\", initial=0.0, dtype=np.float32),\n", - " Variable(\"prev_lon\", dtype=np.float32, to_write=False, initial=attrgetter(\"lon\")),\n", - " Variable(\"prev_lat\", dtype=np.float32, to_write=False, initial=attrgetter(\"lat\"))]\n", + "extra_vars = [\n", + " Variable(\"distance\", initial=0.0, dtype=np.float32),\n", + " Variable(\"prev_lon\", dtype=np.float32, to_write=False, initial=attrgetter(\"lon\")),\n", + " Variable(\"prev_lat\", dtype=np.float32, to_write=False, initial=attrgetter(\"lat\")),\n", + "]\n", "\n", "DistParticle = JITParticle.add_variables(extra_vars)" ] From c5e3d7553af3a96d22988bb7cc8b342e631469b6 Mon Sep 17 00:00:00 2001 From: Erik van Sebille Date: Mon, 29 Jan 2024 11:11:55 +0100 Subject: [PATCH 06/12] Fixing pre-commit error --- parcels/particle.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/parcels/particle.py b/parcels/particle.py index 57ddc15be..942a4371e 100644 --- a/parcels/particle.py +++ b/parcels/particle.py @@ -239,7 +239,6 @@ def add_variable(cls, var, *args, **kwargs): Variable object to be added. Can be the name of the Variable, a Variable object, or a list of Variable objects """ - if isinstance(var, list): return cls.add_variables(var) if not isinstance(var, Variable): @@ -269,7 +268,6 @@ def add_variables(cls, variables): variables : list of Variable Variable objects to be added. Has to be a list of Variable objects """ - NewParticle = cls for var in variables: NewParticle = NewParticle.add_variable(var) From 29ace74c3b7fd84354dad5cd9a4b1216b10ff808 Mon Sep 17 00:00:00 2001 From: Erik van Sebille Date: Mon, 29 Jan 2024 14:25:36 +0100 Subject: [PATCH 07/12] Updating more tutorials to use Particle.add_variables() --- docs/examples/tutorial_Argofloats.ipynb | 12 ++++++------ docs/examples/tutorial_NestedFields.ipynb | 8 +------- docs/examples/tutorial_analyticaladvection.ipynb | 7 ++++--- docs/examples/tutorial_delaystart.ipynb | 10 +++++----- docs/examples/tutorial_interaction.ipynb | 13 +++++++------ docs/examples/tutorial_interpolation.ipynb | 6 ++---- docs/examples/tutorial_parcels_structure.ipynb | 8 ++------ .../tutorial_particle_field_interaction.ipynb | 7 +------ docs/examples/tutorial_sampling.ipynb | 4 +--- 9 files changed, 29 insertions(+), 46 deletions(-) diff --git a/docs/examples/tutorial_Argofloats.ipynb b/docs/examples/tutorial_Argofloats.ipynb index 3a458c735..a0964edc0 100644 --- a/docs/examples/tutorial_Argofloats.ipynb +++ b/docs/examples/tutorial_Argofloats.ipynb @@ -128,19 +128,19 @@ "\n", "\n", "# Define a new Particle type including extra Variables\n", - "class ArgoParticle(JITParticle):\n", + "ArgoParticle = JITParticle.add_variables([\n", " # Phase of cycle:\n", " # init_descend=0,\n", " # drift=1,\n", " # profile_descend=2,\n", " # profile_ascend=3,\n", " # transmit=4\n", - " cycle_phase = Variable(\"cycle_phase\", dtype=np.int32, initial=0.0)\n", - " cycle_age = Variable(\"cycle_age\", dtype=np.float32, initial=0.0)\n", - " drift_age = Variable(\"drift_age\", dtype=np.float32, initial=0.0)\n", + " Variable(\"cycle_phase\", dtype=np.int32, initial=0.0),\n", + " Variable(\"cycle_age\", dtype=np.float32, initial=0.0),\n", + " Variable(\"drift_age\", dtype=np.float32, initial=0.0),\n", " # if fieldset has temperature\n", - " # temp = Variable('temp', dtype=np.float32, initial=np.nan)\n", - "\n", + " # Variable('temp', dtype=np.float32, initial=np.nan),\n", + "])\n", "\n", "# Initiate one Argo float in the Agulhas Current\n", "pset = ParticleSet(\n", diff --git a/docs/examples/tutorial_NestedFields.ipynb b/docs/examples/tutorial_NestedFields.ipynb index e9b1983c0..760d0b76e 100644 --- a/docs/examples/tutorial_NestedFields.ipynb +++ b/docs/examples/tutorial_NestedFields.ipynb @@ -226,16 +226,10 @@ } ], "source": [ - "from parcels import Variable\n", - "\n", - "\n", "def SampleNestedFieldIndex(particle, fieldset, time):\n", " particle.f = fieldset.F[time, particle.depth, particle.lat, particle.lon]\n", "\n", - "\n", - "class SampleParticle(JITParticle):\n", - " f = Variable(\"f\", dtype=np.int32)\n", - "\n", + "SampleParticle = JITParticle.add_variable(\"f\", dtype=np.int32)\n", "\n", "pset = ParticleSet(fieldset, pclass=SampleParticle, lon=[1000], lat=[500])\n", "pset.execute(SampleNestedFieldIndex, runtime=1)\n", diff --git a/docs/examples/tutorial_analyticaladvection.ipynb b/docs/examples/tutorial_analyticaladvection.ipynb index 8c25d8461..72aec5f80 100644 --- a/docs/examples/tutorial_analyticaladvection.ipynb +++ b/docs/examples/tutorial_analyticaladvection.ipynb @@ -154,9 +154,10 @@ " particle.radius = fieldset.R[time, particle.depth, particle.lat, particle.lon]\n", "\n", "\n", - "class MyParticle(ScipyParticle):\n", - " radius = Variable(\"radius\", dtype=np.float32, initial=0.0)\n", - " radius_start = Variable(\"radius_start\", dtype=np.float32, initial=0.0)\n", + "MyParticle = ScipyParticle.add_variables([\n", + " Variable(\"radius\", dtype=np.float32, initial=0.0),\n", + " Variable(\"radius_start\", dtype=np.float32, initial=0.0)\n", + "])\n", "\n", "\n", "pset = ParticleSet(fieldsetRR, pclass=MyParticle, lon=0, lat=4e3, time=0)\n", diff --git a/docs/examples/tutorial_delaystart.ipynb b/docs/examples/tutorial_delaystart.ipynb index c4f3cc725..d6b4abf18 100644 --- a/docs/examples/tutorial_delaystart.ipynb +++ b/docs/examples/tutorial_delaystart.ipynb @@ -32416,11 +32416,11 @@ } ], "source": [ - "class GrowingParticle(JITParticle):\n", - " mass = Variable(\"mass\", initial=0)\n", - " splittime = Variable(\"splittime\", initial=-1)\n", - " splitmass = Variable(\"splitmass\", initial=0)\n", - "\n", + "GrowingParticle = JITParticle.add_variables([\n", + " Variable(\"mass\", initial=0),\n", + " Variable(\"splittime\", initial=-1),\n", + " Variable(\"splitmass\", initial=0),\n", + "])\n", "\n", "def GrowParticles(particle, fieldset, time):\n", " # 25% chance per timestep for particle to grow\n", diff --git a/docs/examples/tutorial_interaction.ipynb b/docs/examples/tutorial_interaction.ipynb index 23faff099..1d0108cec 100644 --- a/docs/examples/tutorial_interaction.ipynb +++ b/docs/examples/tutorial_interaction.ipynb @@ -152,8 +152,9 @@ "\n", "# Create custom particle class with extra variable that indicates\n", "# whether the interaction kernel should be executed on this particle.\n", - "class InteractingParticle(ScipyParticle):\n", - " attractor = Variable(\"attractor\", dtype=np.bool_, to_write=\"once\")\n", + "InteractingParticle = ScipyParticle.add_variable(\"attractor\",\n", + " dtype=np.bool_,\n", + " to_write=\"once\")\n", "\n", "\n", "attractor = [\n", @@ -30838,10 +30839,10 @@ "\n", "# Create custom InteractionParticle class\n", "# with extra variables nearest_neighbor and mass\n", - "class MergeParticle(ScipyInteractionParticle):\n", - " nearest_neighbor = Variable(\"nearest_neighbor\", dtype=np.int64, to_write=False)\n", - " mass = Variable(\"mass\", initial=1, dtype=np.float32)\n", - "\n", + "MergeParticle = ScipyInteractionParticle.add_variables([\n", + " Variable(\"nearest_neighbor\", dtype=np.int64, to_write=False),\n", + " Variable(\"mass\", initial=1, dtype=np.float32),\n", + "])\n", "\n", "pset = ParticleSet(\n", " fieldset=fieldset,\n", diff --git a/docs/examples/tutorial_interpolation.ipynb b/docs/examples/tutorial_interpolation.ipynb index b23666a0c..925b28f93 100644 --- a/docs/examples/tutorial_interpolation.ipynb +++ b/docs/examples/tutorial_interpolation.ipynb @@ -28,7 +28,7 @@ "import numpy as np\n", "from matplotlib import cm\n", "\n", - "from parcels import FieldSet, JITParticle, ParticleSet, Variable" + "from parcels import FieldSet, JITParticle, ParticleSet" ] }, { @@ -74,9 +74,7 @@ "metadata": {}, "outputs": [], "source": [ - "class SampleParticle(JITParticle):\n", - " p = Variable(\"p\", dtype=np.float32)\n", - "\n", + "SampleParticle = JITParticle.add_variable(\"p\", dtype=np.float32)\n", "\n", "def SampleP(particle, fieldset, time):\n", " particle.p = fieldset.P[time, particle.depth, particle.lat, particle.lon]" diff --git a/docs/examples/tutorial_parcels_structure.ipynb b/docs/examples/tutorial_parcels_structure.ipynb index 94ac9d9c0..1b98374be 100644 --- a/docs/examples/tutorial_parcels_structure.ipynb +++ b/docs/examples/tutorial_parcels_structure.ipynb @@ -201,12 +201,8 @@ "source": [ "from parcels import JITParticle, ParticleSet, Variable\n", "\n", - "# Define a new particle class\n", - "\n", - "\n", - "class AgeParticle(JITParticle): # It is a JIT particle\n", - " age = Variable(\"age\", initial=0) # Variable 'age' is added with initial value 0.\n", - "\n", + "# Define a new particleclass with Variable 'age' with initial value 0.\n", + "AgeParticle = JITParticle.add_variable(Variable(\"age\", initial=0))\n", "\n", "pset = ParticleSet(\n", " fieldset=fieldset, # the fields that the particleset uses\n", diff --git a/docs/examples/tutorial_particle_field_interaction.ipynb b/docs/examples/tutorial_particle_field_interaction.ipynb index 5f8dfa9cd..45a996ba3 100644 --- a/docs/examples/tutorial_particle_field_interaction.ipynb +++ b/docs/examples/tutorial_particle_field_interaction.ipynb @@ -56,7 +56,6 @@ " FieldSet,\n", " ParticleSet,\n", " ScipyParticle,\n", - " Variable,\n", " download_example_dataset,\n", ")" ] @@ -160,11 +159,7 @@ "metadata": {}, "outputs": [], "source": [ - "class VectorParticle(ScipyParticle):\n", - " \"\"\"initialise particle concentration c with a non-zero value\"\"\"\n", - "\n", - " c = Variable(\"c\", dtype=np.float32, initial=100.0)\n", - "\n", + "VectorParticle = ScipyParticle.add_variable(\"c\", dtype=np.float32, initial=100.0)\n", "\n", "def Interaction(particle, fieldset, time):\n", " \"\"\"define the interaction between the particle and the fieldset.C field.\n", diff --git a/docs/examples/tutorial_sampling.ipynb b/docs/examples/tutorial_sampling.ipynb index d702feac2..19c735976 100644 --- a/docs/examples/tutorial_sampling.ipynb +++ b/docs/examples/tutorial_sampling.ipynb @@ -125,9 +125,7 @@ "metadata": {}, "outputs": [], "source": [ - "class SampleParticle(JITParticle): # Define a new particle class\n", - " temperature = Variable(\"temperature\")\n", - "\n", + "SampleParticle = JITParticle.add_variable(\"temperature\")\n", "\n", "pset = ParticleSet(\n", " fieldset=fieldset, pclass=SampleParticle, lon=lon, lat=lat, time=time\n", From c0baa1819606b49a35b024c4914971f80e210bd6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 29 Jan 2024 13:26:03 +0000 Subject: [PATCH 08/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/examples/tutorial_Argofloats.ipynb | 28 ++++++++++--------- docs/examples/tutorial_NestedFields.ipynb | 1 + .../tutorial_analyticaladvection.ipynb | 10 ++++--- docs/examples/tutorial_delaystart.ipynb | 13 +++++---- docs/examples/tutorial_interaction.ipynb | 16 ++++++----- docs/examples/tutorial_interpolation.ipynb | 1 + .../tutorial_particle_field_interaction.ipynb | 1 + 7 files changed, 41 insertions(+), 29 deletions(-) diff --git a/docs/examples/tutorial_Argofloats.ipynb b/docs/examples/tutorial_Argofloats.ipynb index a0964edc0..f672e604b 100644 --- a/docs/examples/tutorial_Argofloats.ipynb +++ b/docs/examples/tutorial_Argofloats.ipynb @@ -128,19 +128,21 @@ "\n", "\n", "# Define a new Particle type including extra Variables\n", - "ArgoParticle = JITParticle.add_variables([\n", - " # Phase of cycle:\n", - " # init_descend=0,\n", - " # drift=1,\n", - " # profile_descend=2,\n", - " # profile_ascend=3,\n", - " # transmit=4\n", - " Variable(\"cycle_phase\", dtype=np.int32, initial=0.0),\n", - " Variable(\"cycle_age\", dtype=np.float32, initial=0.0),\n", - " Variable(\"drift_age\", dtype=np.float32, initial=0.0),\n", - " # if fieldset has temperature\n", - " # Variable('temp', dtype=np.float32, initial=np.nan),\n", - "])\n", + "ArgoParticle = JITParticle.add_variables(\n", + " [\n", + " # Phase of cycle:\n", + " # init_descend=0,\n", + " # drift=1,\n", + " # profile_descend=2,\n", + " # profile_ascend=3,\n", + " # transmit=4\n", + " Variable(\"cycle_phase\", dtype=np.int32, initial=0.0),\n", + " Variable(\"cycle_age\", dtype=np.float32, initial=0.0),\n", + " Variable(\"drift_age\", dtype=np.float32, initial=0.0),\n", + " # if fieldset has temperature\n", + " # Variable('temp', dtype=np.float32, initial=np.nan),\n", + " ]\n", + ")\n", "\n", "# Initiate one Argo float in the Agulhas Current\n", "pset = ParticleSet(\n", diff --git a/docs/examples/tutorial_NestedFields.ipynb b/docs/examples/tutorial_NestedFields.ipynb index 760d0b76e..664e01365 100644 --- a/docs/examples/tutorial_NestedFields.ipynb +++ b/docs/examples/tutorial_NestedFields.ipynb @@ -229,6 +229,7 @@ "def SampleNestedFieldIndex(particle, fieldset, time):\n", " particle.f = fieldset.F[time, particle.depth, particle.lat, particle.lon]\n", "\n", + "\n", "SampleParticle = JITParticle.add_variable(\"f\", dtype=np.int32)\n", "\n", "pset = ParticleSet(fieldset, pclass=SampleParticle, lon=[1000], lat=[500])\n", diff --git a/docs/examples/tutorial_analyticaladvection.ipynb b/docs/examples/tutorial_analyticaladvection.ipynb index 72aec5f80..b3b67496f 100644 --- a/docs/examples/tutorial_analyticaladvection.ipynb +++ b/docs/examples/tutorial_analyticaladvection.ipynb @@ -154,10 +154,12 @@ " particle.radius = fieldset.R[time, particle.depth, particle.lat, particle.lon]\n", "\n", "\n", - "MyParticle = ScipyParticle.add_variables([\n", - " Variable(\"radius\", dtype=np.float32, initial=0.0),\n", - " Variable(\"radius_start\", dtype=np.float32, initial=0.0)\n", - "])\n", + "MyParticle = ScipyParticle.add_variables(\n", + " [\n", + " Variable(\"radius\", dtype=np.float32, initial=0.0),\n", + " Variable(\"radius_start\", dtype=np.float32, initial=0.0),\n", + " ]\n", + ")\n", "\n", "\n", "pset = ParticleSet(fieldsetRR, pclass=MyParticle, lon=0, lat=4e3, time=0)\n", diff --git a/docs/examples/tutorial_delaystart.ipynb b/docs/examples/tutorial_delaystart.ipynb index d6b4abf18..68d0bce87 100644 --- a/docs/examples/tutorial_delaystart.ipynb +++ b/docs/examples/tutorial_delaystart.ipynb @@ -32416,11 +32416,14 @@ } ], "source": [ - "GrowingParticle = JITParticle.add_variables([\n", - " Variable(\"mass\", initial=0),\n", - " Variable(\"splittime\", initial=-1),\n", - " Variable(\"splitmass\", initial=0),\n", - "])\n", + "GrowingParticle = JITParticle.add_variables(\n", + " [\n", + " Variable(\"mass\", initial=0),\n", + " Variable(\"splittime\", initial=-1),\n", + " Variable(\"splitmass\", initial=0),\n", + " ]\n", + ")\n", + "\n", "\n", "def GrowParticles(particle, fieldset, time):\n", " # 25% chance per timestep for particle to grow\n", diff --git a/docs/examples/tutorial_interaction.ipynb b/docs/examples/tutorial_interaction.ipynb index 1d0108cec..2d6d09f81 100644 --- a/docs/examples/tutorial_interaction.ipynb +++ b/docs/examples/tutorial_interaction.ipynb @@ -152,9 +152,9 @@ "\n", "# Create custom particle class with extra variable that indicates\n", "# whether the interaction kernel should be executed on this particle.\n", - "InteractingParticle = ScipyParticle.add_variable(\"attractor\",\n", - " dtype=np.bool_,\n", - " to_write=\"once\")\n", + "InteractingParticle = ScipyParticle.add_variable(\n", + " \"attractor\", dtype=np.bool_, to_write=\"once\"\n", + ")\n", "\n", "\n", "attractor = [\n", @@ -30839,10 +30839,12 @@ "\n", "# Create custom InteractionParticle class\n", "# with extra variables nearest_neighbor and mass\n", - "MergeParticle = ScipyInteractionParticle.add_variables([\n", - " Variable(\"nearest_neighbor\", dtype=np.int64, to_write=False),\n", - " Variable(\"mass\", initial=1, dtype=np.float32),\n", - "])\n", + "MergeParticle = ScipyInteractionParticle.add_variables(\n", + " [\n", + " Variable(\"nearest_neighbor\", dtype=np.int64, to_write=False),\n", + " Variable(\"mass\", initial=1, dtype=np.float32),\n", + " ]\n", + ")\n", "\n", "pset = ParticleSet(\n", " fieldset=fieldset,\n", diff --git a/docs/examples/tutorial_interpolation.ipynb b/docs/examples/tutorial_interpolation.ipynb index 925b28f93..421bfc830 100644 --- a/docs/examples/tutorial_interpolation.ipynb +++ b/docs/examples/tutorial_interpolation.ipynb @@ -76,6 +76,7 @@ "source": [ "SampleParticle = JITParticle.add_variable(\"p\", dtype=np.float32)\n", "\n", + "\n", "def SampleP(particle, fieldset, time):\n", " particle.p = fieldset.P[time, particle.depth, particle.lat, particle.lon]" ] diff --git a/docs/examples/tutorial_particle_field_interaction.ipynb b/docs/examples/tutorial_particle_field_interaction.ipynb index 45a996ba3..143837ba9 100644 --- a/docs/examples/tutorial_particle_field_interaction.ipynb +++ b/docs/examples/tutorial_particle_field_interaction.ipynb @@ -161,6 +161,7 @@ "source": [ "VectorParticle = ScipyParticle.add_variable(\"c\", dtype=np.float32, initial=100.0)\n", "\n", + "\n", "def Interaction(particle, fieldset, time):\n", " \"\"\"define the interaction between the particle and the fieldset.C field.\n", " the exchange is obtained as a discretized mass transfer equation,\n", From 7c9d10ce33de27208d8ef8c6e2fb25ece70827c8 Mon Sep 17 00:00:00 2001 From: Erik van Sebille Date: Tue, 30 Jan 2024 11:15:07 +0100 Subject: [PATCH 09/12] Final set of changes to tests and examples with add_variables() --- .../documentation_unstuck_Agrid.ipynb | 11 ++- docs/examples/parcels_tutorial.ipynb | 4 +- docs/examples/tutorial_sampling.ipynb | 18 ++--- parcels/particle.py | 6 +- tests/test_advection.py | 10 +-- tests/test_diffusion.py | 4 +- tests/test_fieldset.py | 26 +++--- tests/test_fieldset_sampling.py | 15 ++-- tests/test_grids.py | 81 +++++++++---------- tests/test_interaction.py | 13 +-- tests/test_kernel_execution.py | 4 +- tests/test_kernel_language.py | 57 +++++-------- tests/test_particlefile.py | 20 ++--- tests/test_particles.py | 19 ++--- tests/test_particlesets.py | 5 +- 15 files changed, 127 insertions(+), 166 deletions(-) diff --git a/docs/examples/documentation_unstuck_Agrid.ipynb b/docs/examples/documentation_unstuck_Agrid.ipynb index 7aa21def6..dc63f8bab 100644 --- a/docs/examples/documentation_unstuck_Agrid.ipynb +++ b/docs/examples/documentation_unstuck_Agrid.ipynb @@ -809,10 +809,13 @@ "metadata": {}, "outputs": [], "source": [ - "class DisplacementParticle(JITParticle):\n", - " dU = Variable(\"dU\")\n", - " dV = Variable(\"dV\")\n", - " d2s = Variable(\"d2s\", initial=1e3)\n", + "DisplacementParticle = JITParticle.add_variables(\n", + " [\n", + " Variable(\"dU\"),\n", + " Variable(\"dV\"),\n", + " Variable(\"d2s\", initial=1e3),\n", + " ]\n", + ")\n", "\n", "\n", "def set_displacement(particle, fieldset, time):\n", diff --git a/docs/examples/parcels_tutorial.ipynb b/docs/examples/parcels_tutorial.ipynb index d23e823f1..947a00094 100644 --- a/docs/examples/parcels_tutorial.ipynb +++ b/docs/examples/parcels_tutorial.ipynb @@ -13601,9 +13601,7 @@ "metadata": {}, "outputs": [], "source": [ - "SampleParticle = JITParticle.add_variable(\n", - " \"p\"\n", - ") # add variable p to Particle class to store the pressure field" + "SampleParticle = JITParticle.add_variable(\"p\")" ] }, { diff --git a/docs/examples/tutorial_sampling.ipynb b/docs/examples/tutorial_sampling.ipynb index 19c735976..609e0b33a 100644 --- a/docs/examples/tutorial_sampling.ipynb +++ b/docs/examples/tutorial_sampling.ipynb @@ -305,9 +305,12 @@ "metadata": {}, "outputs": [], "source": [ - "class SampleParticle(JITParticle):\n", - " U = Variable(\"U\", dtype=np.float32, initial=np.nan)\n", - " V = Variable(\"V\", dtype=np.float32, initial=np.nan)\n", + "SampleParticle = JITParticle.add_variables(\n", + " [\n", + " Variable(\"U\", dtype=np.float32, initial=np.nan),\n", + " Variable(\"V\", dtype=np.float32, initial=np.nan),\n", + " ]\n", + ")\n", "\n", "\n", "def SampleVel_correct(particle, fieldset, time):\n", @@ -338,12 +341,9 @@ "metadata": {}, "outputs": [], "source": [ - "class SampleParticleOnce(JITParticle):\n", - " \"\"\"Define a new particle class with Variable 'temperature'\n", - " initially zero and only written once\"\"\"\n", - "\n", - " temperature = Variable(\"temperature\", initial=0, to_write=\"once\")\n", - "\n", + "SampleParticleOnce = JITParticle.add_variables(\n", + " \"temperature\", initial=0, to_write=\"once\"\n", + ")\n", "\n", "pset = ParticleSet(\n", " fieldset=fieldset, pclass=SampleParticleOnce, lon=lon, lat=lat, time=time\n", diff --git a/parcels/particle.py b/parcels/particle.py index 942a4371e..52841a89a 100644 --- a/parcels/particle.py +++ b/parcels/particle.py @@ -283,9 +283,9 @@ def set_lonlatdepth_dtype(cls, dtype): cls.depth_nextloop.dtype = dtype -class ScipyInteractionParticle(ScipyParticle): - vert_dist = Variable("vert_dist", dtype=np.float32) - horiz_dist = Variable("horiz_dist", dtype=np.float32) +ScipyInteractionParticle = ScipyParticle.add_variables([ + Variable("vert_dist", dtype=np.float32), + Variable("horiz_dist", dtype=np.float32)]) class JITParticle(ScipyParticle): diff --git a/tests/test_advection.py b/tests/test_advection.py index 9d68ec2b2..b1b3b4c1b 100644 --- a/tests/test_advection.py +++ b/tests/test_advection.py @@ -19,7 +19,6 @@ ParticleSet, ScipyParticle, StatusCode, - Variable, ) ptype = {'scipy': ScipyParticle, 'jit': JITParticle} @@ -301,8 +300,7 @@ def test_stationary_eddy(fieldset_stationary, mode, method, rtol, diffField, npa dt = delta(minutes=3).total_seconds() endtime = delta(hours=6).total_seconds() - class RK45Particles(ptype[mode]): - next_dt = Variable('next_dt', dtype=np.float32, initial=dt) + RK45Particles = ptype[mode]('next_dt', dtype=np.float32, initial=dt) pclass = RK45Particles if method == 'RK45' else ptype[mode] pset = ParticleSet(fieldset, pclass=pclass, lon=lon, lat=lat) @@ -395,8 +393,7 @@ def test_moving_eddy(fieldset_moving, mode, method, rtol, diffField, npart=1): dt = delta(minutes=3).total_seconds() endtime = delta(hours=6).total_seconds() - class RK45Particles(ptype[mode]): - next_dt = Variable('next_dt', dtype=np.float32, initial=dt) + RK45Particles = ptype[mode]('next_dt', dtype=np.float32, initial=dt) pclass = RK45Particles if method == 'RK45' else ptype[mode] pset = ParticleSet(fieldset, pclass=pclass, lon=lon, lat=lat) @@ -461,8 +458,7 @@ def test_decaying_eddy(fieldset_decaying, mode, method, rtol, diffField, npart=1 dt = delta(minutes=3).total_seconds() endtime = delta(hours=6).total_seconds() - class RK45Particles(ptype[mode]): - next_dt = Variable('next_dt', dtype=np.float32, initial=dt) + RK45Particles = ptype[mode]('next_dt', dtype=np.float32, initial=dt) pclass = RK45Particles if method == 'RK45' else ptype[mode] pset = ParticleSet(fieldset, pclass=pclass, lon=lon, lat=lat) diff --git a/tests/test_diffusion.py b/tests/test_diffusion.py index dda28ed34..02af18253 100644 --- a/tests/test_diffusion.py +++ b/tests/test_diffusion.py @@ -15,7 +15,6 @@ ParticleSet, RectilinearZGrid, ScipyParticle, - Variable, ) ptype = {'scipy': ScipyParticle, 'jit': JITParticle} @@ -132,8 +131,7 @@ def test_randomvonmises(mode, mu, kappa, npart=10000): # Set random seed ParcelsRandom.seed(1234) - class AngleParticle(ptype[mode]): - angle = Variable('angle') + AngleParticle = ptype[mode].add_variable('angle') pset = ParticleSet(fieldset=fieldset, pclass=AngleParticle, lon=np.zeros(npart), lat=np.zeros(npart), depth=np.zeros(npart)) def vonmises(particle, fieldset, time): diff --git a/tests/test_fieldset.py b/tests/test_fieldset.py index a5ba1bf49..2f17eb09a 100644 --- a/tests/test_fieldset.py +++ b/tests/test_fieldset.py @@ -834,13 +834,14 @@ def sampleTemp(particle, fieldset, time): # test if we can sample a non-timevarying field too particle.d = fieldset.D[0, 0, particle.lat, particle.lon] - class MyParticle(ptype[mode]): - temp = Variable('temp', dtype=np.float32, initial=20.) - u1 = Variable('u1', dtype=np.float32, initial=0.) - u2 = Variable('u2', dtype=np.float32, initial=0.) - v1 = Variable('v1', dtype=np.float32, initial=0.) - v2 = Variable('v2', dtype=np.float32, initial=0.) - d = Variable('d', dtype=np.float32, initial=0.) + MyParticle = ptype[mode].add_variables([ + Variable('temp', dtype=np.float32, initial=20.), + Variable('u1', dtype=np.float32, initial=0.), + Variable('u2', dtype=np.float32, initial=0.), + Variable('v1', dtype=np.float32, initial=0.), + Variable('v2', dtype=np.float32, initial=0.), + Variable('d', dtype=np.float32, initial=0.), + ]) pset = ParticleSet.from_list(fieldset, pclass=MyParticle, lon=[0.5], lat=[0.5], depth=[0.5]) pset.execute(AdvectionRK4_3D + pset.Kernel(sampleTemp), runtime=delta(hours=51), dt=delta(hours=dt_sign*1)) @@ -944,10 +945,10 @@ def test_fieldset_initialisation_kernel_dask(time2, tmpdir, filename='test_parce def SampleField(particle, fieldset, time): particle.u_kernel, particle.v_kernel = fieldset.UV[time, particle.depth, particle.lat, particle.lon] - class SampleParticle(JITParticle): - u_kernel = Variable('u_kernel', dtype=np.float32, initial=0.) - v_kernel = Variable('v_kernel', dtype=np.float32, initial=0.) - u_scipy = Variable('u_scipy', dtype=np.float32, initial=0.) + SampleParticle = JITParticle.add_variables([ + Variable('u_kernel', dtype=np.float32, initial=0.), + Variable('v_kernel', dtype=np.float32, initial=0.), + Variable('u_scipy', dtype=np.float32, initial=0.)]) pset = ParticleSet(fieldset, pclass=SampleParticle, time=[0, time2], lon=[0.5, 0.5], lat=[0.5, 0.5], depth=[0.5, 0.5]) @@ -1084,8 +1085,7 @@ def test_deferredload_simplefield(mode, direction, time_extrapolation, tmpdir, t fieldset = FieldSet.from_netcdf(filename, {'U': 'U', 'V': 'V'}, {'lon': 'x', 'lat': 'y', 'time': 't'}, deferred_load=True, mesh='flat', allow_time_extrapolation=time_extrapolation) - class SamplingParticle(ptype[mode]): - p = Variable('p') + SamplingParticle = ptype[mode].add_variable("p") pset = ParticleSet(fieldset, SamplingParticle, lon=0.5, lat=0.5) def SampleU(particle, fieldset, time): diff --git a/tests/test_fieldset_sampling.py b/tests/test_fieldset_sampling.py index 68aaca45f..8ba473dc6 100644 --- a/tests/test_fieldset_sampling.py +++ b/tests/test_fieldset_sampling.py @@ -23,11 +23,10 @@ def pclass(mode): - class SampleParticle(ptype[mode]): - u = Variable('u', dtype=np.float32) - v = Variable('v', dtype=np.float32) - p = Variable('p', dtype=np.float32) - return SampleParticle + return ptype[mode].add_variables([ + Variable('u', dtype=np.float32), + Variable('v', dtype=np.float32), + Variable('p', dtype=np.float32)]) def k_sample_uv(): @@ -628,8 +627,7 @@ def test_sampling_multigrids_non_vectorfield_from_file(mode, npart, tmpdir, chs, assert fieldset.U.grid is fieldset.V.grid assert fieldset.U.grid is not fieldset.B.grid - class TestParticle(ptype[mode]): - sample_var = Variable('sample_var', initial=0.) + TestParticle = ptype[mode].add_variable('sample_var', initial=0.) pset = ParticleSet.from_line(fieldset, pclass=TestParticle, start=[0.3, 0.3], finish=[0.7, 0.7], size=npart) @@ -672,8 +670,7 @@ def test_sampling_multigrids_non_vectorfield(mode, npart): assert fieldset.U.grid is fieldset.V.grid assert fieldset.U.grid is not fieldset.B.grid - class TestParticle(ptype[mode]): - sample_var = Variable('sample_var', initial=0.) + TestParticle = ptype[mode].add_variable('sample_var', initial=0.) pset = ParticleSet.from_line(fieldset, pclass=TestParticle, start=[0.3, 0.3], finish=[0.7, 0.7], size=npart) diff --git a/tests/test_grids.py b/tests/test_grids.py index a317cbbce..76a11e0ea 100644 --- a/tests/test_grids.py +++ b/tests/test_grids.py @@ -84,9 +84,9 @@ def sampleTemp(particle, fieldset, time): particle.temp0 = fieldset.temp0[time+particle.dt, particle.depth, particle.lat, particle.lon] particle.temp1 = fieldset.temp1[time+particle.dt, particle.depth, particle.lat, particle.lon] - class MyParticle(ptype[mode]): - temp0 = Variable('temp0', dtype=np.float32, initial=20.) - temp1 = Variable('temp1', dtype=np.float32, initial=20.) + MyParticle = ptype[mode].add_variables([ + Variable('temp0', dtype=np.float32, initial=20.), + Variable('temp1', dtype=np.float32, initial=20.)]) pset = ParticleSet.from_list(fieldset, MyParticle, lon=[3001], lat=[5001], repeatdt=1) @@ -224,8 +224,7 @@ def bath_func(lon): def sampleTemp(particle, fieldset, time): particle.temp = fieldset.temp[time, particle.depth, particle.lat, particle.lon] - class MyParticle(ptype[mode]): - temp = Variable('temp', dtype=np.float32, initial=20.) + MyParticle = ptype[mode].add_variable('temp', dtype=np.float32, initial=20.) lon = 400 lat = 0 @@ -309,8 +308,7 @@ def bath_func(lon): rel_depth_field = Field('relDepth', rel_depth_data, grid=grid) fieldset = FieldSet(u_field, v_field, fields={'relDepth': rel_depth_field}) - class MyParticle(ptype[mode]): - relDepth = Variable('relDepth', dtype=np.float32, initial=20.) + MyParticle = ptype[mode].add_variable('relDepth', dtype=np.float32, initial=20.) def moveEast(particle, fieldset, time): particle_dlon += 5 * particle.dt # noqa @@ -352,8 +350,7 @@ def sampleSpeed(particle, fieldset, time): u, v = fieldset.UV[time, particle.depth, particle.lat, particle.lon] particle.speed = math.sqrt(u*u+v*v) - class MyParticle(ptype[mode]): - speed = Variable('speed', dtype=np.float32, initial=0.) + MyParticle = ptype[mode].add_variable('speed', dtype=np.float32, initial=0.) pset = ParticleSet.from_list(fieldset, MyParticle, lon=[400, -200], lat=[600, 600]) pset.execute(pset.Kernel(sampleSpeed), runtime=1) @@ -380,9 +377,9 @@ def test_nemo_grid(mode): def sampleVel(particle, fieldset, time): (particle.zonal, particle.meridional) = fieldset.UV[time, particle.depth, particle.lat, particle.lon] - class MyParticle(ptype[mode]): - zonal = Variable('zonal', dtype=np.float32, initial=0.) - meridional = Variable('meridional', dtype=np.float32, initial=0.) + MyParticle = ptype[mode].add_variables([ + Variable('zonal', dtype=np.float32, initial=0.), + Variable('meridional', dtype=np.float32, initial=0.)]) lonp = 175.5 latp = 81.5 @@ -437,9 +434,9 @@ def test_cgrid_uniform_2dvel(mode, time): def sampleVel(particle, fieldset, time): (particle.zonal, particle.meridional) = fieldset.UV[time, particle.depth, particle.lat, particle.lon] - class MyParticle(ptype[mode]): - zonal = Variable('zonal', dtype=np.float32, initial=0.) - meridional = Variable('meridional', dtype=np.float32, initial=0.) + MyParticle = ptype[mode].add_variables([ + Variable('zonal', dtype=np.float32, initial=0.), + Variable('meridional', dtype=np.float32, initial=0.)]) pset = ParticleSet.from_list(fieldset, MyParticle, lon=.7, lat=.3) pset.execute(pset.Kernel(sampleVel), runtime=1) @@ -497,10 +494,10 @@ def test_cgrid_uniform_3dvel(mode, vert_mode, time): def sampleVel(particle, fieldset, time): (particle.zonal, particle.meridional, particle.vertical) = fieldset.UVW[time, particle.depth, particle.lat, particle.lon] - class MyParticle(ptype[mode]): - zonal = Variable('zonal', dtype=np.float32, initial=0.) - meridional = Variable('meridional', dtype=np.float32, initial=0.) - vertical = Variable('vertical', dtype=np.float32, initial=0.) + MyParticle = ptype[mode].add_variables([ + Variable('zonal', dtype=np.float32, initial=0.), + Variable('meridional', dtype=np.float32, initial=0.), + Variable('vertical', dtype=np.float32, initial=0.)]) pset = ParticleSet.from_list(fieldset, MyParticle, lon=.7, lat=.3, depth=.2) pset.execute(pset.Kernel(sampleVel), runtime=1) @@ -554,10 +551,10 @@ def test_cgrid_uniform_3dvel_spherical(mode, vert_mode, time): def sampleVel(particle, fieldset, time): (particle.zonal, particle.meridional, particle.vertical) = fieldset.UVW[time, particle.depth, particle.lat, particle.lon] - class MyParticle(ptype[mode]): - zonal = Variable('zonal', dtype=np.float32, initial=0.) - meridional = Variable('meridional', dtype=np.float32, initial=0.) - vertical = Variable('vertical', dtype=np.float32, initial=0.) + MyParticle = ptype[mode].add_variables([ + Variable('zonal', dtype=np.float32, initial=0.), + Variable('meridional', dtype=np.float32, initial=0.), + Variable('vertical', dtype=np.float32, initial=0.)]) lonp = 179.8 latp = 81.35 @@ -601,12 +598,12 @@ def OutBoundsError(particle, fieldset, time): particle_ddepth -= 3 # noqa particle.state = StatusCode.Success - class MyParticle(ptype[mode]): - zonal = Variable('zonal', dtype=np.float32, initial=0.) - meridional = Variable('meridional', dtype=np.float32, initial=0.) - vert = Variable('vert', dtype=np.float32, initial=0.) - tracer = Variable('tracer', dtype=np.float32, initial=0.) - out_of_bounds = Variable('out_of_bounds', dtype=np.float32, initial=0.) + MyParticle = ptype[mode].add_variables([ + Variable('zonal', dtype=np.float32, initial=0.), + Variable('meridional', dtype=np.float32, initial=0.), + Variable('vert', dtype=np.float32, initial=0.), + Variable('tracer', dtype=np.float32, initial=0.), + Variable('out_of_bounds', dtype=np.float32, initial=0.)]) pset = ParticleSet.from_list(fieldset, MyParticle, lon=[3, 5, 1], lat=[3, 5, 1], depth=[3, 7, 11]) pset.execute(pset.Kernel(sampleVel) + OutBoundsError, runtime=1) @@ -699,9 +696,9 @@ def UpdateR(particle, fieldset, time): particle.radius_start = fieldset.R[time, particle.depth, particle.lat, particle.lon] particle.radius = fieldset.R[time, particle.depth, particle.lat, particle.lon] - class MyParticle(ptype[mode]): - radius = Variable('radius', dtype=np.float32, initial=0.) - radius_start = Variable('radius_start', dtype=np.float32, initial=0.) + MyParticle = ptype[mode].add_variables([ + Variable('radius', dtype=np.float32, initial=0.), + Variable('radius_start', dtype=np.float32, initial=0.)]) pset = ParticleSet(fieldset, pclass=MyParticle, lon=0, lat=4e3, time=0) @@ -777,9 +774,9 @@ def UpdateR(particle, fieldset, time): particle.radius_start = fieldset.R[time, particle.depth, particle.lat, particle.lon] particle.radius = fieldset.R[time, particle.depth, particle.lat, particle.lon] - class MyParticle(ptype[mode]): - radius = Variable('radius', dtype=np.float32, initial=0.) - radius_start = Variable('radius_start', dtype=np.float32, initial=0.) + MyParticle = ptype[mode].add_variables([ + Variable('radius', dtype=np.float32, initial=0.), + Variable('radius_start', dtype=np.float32, initial=0.)]) pset = ParticleSet(fieldset, pclass=MyParticle, depth=4e3, lon=0, lat=0, time=0) @@ -856,9 +853,9 @@ def UpdateR(particle, fieldset, time): particle.radius_start = fieldset.R[time, particle.depth, particle.lat, particle.lon] particle.radius = fieldset.R[time, particle.depth, particle.lat, particle.lon] - class MyParticle(ptype[mode]): - radius = Variable('radius', dtype=np.float32, initial=0.) - radius_start = Variable('radius_start', dtype=np.float32, initial=0.) + MyParticle = ptype[mode].add_variables([ + Variable('radius', dtype=np.float32, initial=0.), + Variable('radius_start', dtype=np.float32, initial=0.)]) pset = ParticleSet(fieldset, pclass=MyParticle, depth=-9.995e3, lon=0, lat=0, time=0) @@ -924,10 +921,10 @@ def VelocityInterpolator(particle, fieldset, time): particle.Vvel = fieldset.V[time, particle.depth, particle.lat, particle.lon] particle.Wvel = fieldset.W[time, particle.depth, particle.lat, particle.lon] - class myParticle(ptype[mode]): - Uvel = Variable("Uvel", dtype=np.float32, initial=0.0) - Vvel = Variable("Vvel", dtype=np.float32, initial=0.0) - Wvel = Variable("Wvel", dtype=np.float32, initial=0.0) + myParticle = ptype[mode].add_variables([ + Variable("Uvel", dtype=np.float32, initial=0.0), + Variable("Vvel", dtype=np.float32, initial=0.0), + Variable("Wvel", dtype=np.float32, initial=0.0)]) for pointtype in ["U", "V", "W"]: if gridindexingtype == "pop": diff --git a/tests/test_interaction.py b/tests/test_interaction.py index 2b1f6f286..11cc8029b 100644 --- a/tests/test_interaction.py +++ b/tests/test_interaction.py @@ -52,11 +52,6 @@ def fieldset(xdim=20, ydim=20, mesh='spherical'): return FieldSet.from_data(data, dimensions, mesh=mesh) -class MergeParticle(ScipyInteractionParticle): - nearest_neighbor = Variable('nearest_neighbor', dtype=np.int64, to_write=False) - mass = Variable('mass', initial=1, dtype=np.float32) - - @pytest.fixture(name="fieldset") def fieldset_fixture(xdim=20, ydim=20): return fieldset(xdim=xdim, ydim=ydim) @@ -133,6 +128,9 @@ def test_neighbor_merge(fieldset): lats = [0.0, 0.0, 0.0, 0.0] # Distance in meters R_earth*0.2 degrees interaction_distance = 6371000*5.5*np.pi/180 + MergeParticle = ScipyInteractionParticle.add_variables([ + Variable('nearest_neighbor', dtype=np.int64, to_write=False), + Variable('mass', initial=1, dtype=np.float32)]) pset = ParticleSet(fieldset, pclass=MergeParticle, lon=lons, lat=lats, interaction_distance=interaction_distance) pyfunc_inter = (pset.InteractionKernel(NearestNeighborWithinRange) @@ -144,16 +142,13 @@ def test_neighbor_merge(fieldset): assert len(pset) == 1 -class AttractingParticle(ScipyInteractionParticle): - attractor = Variable('attractor', dtype=np.bool_, to_write='once') - - @pytest.mark.parametrize('mode', ['scipy']) def test_asymmetric_attraction(fieldset, mode): lons = [0.0, 0.1, 0.2] lats = [0.0, 0.0, 0.0] # Distance in meters R_earth*0.2 degrees interaction_distance = 6371000*5.5*np.pi/180 + AttractingParticle = ScipyInteractionParticle.add_variable('attractor', dtype=np.bool_, to_write='once') pset = ParticleSet(fieldset, pclass=AttractingParticle, lon=lons, lat=lats, interaction_distance=interaction_distance, attractor=[True, False, False]) diff --git a/tests/test_kernel_execution.py b/tests/test_kernel_execution.py index 7644d1d63..836e46864 100644 --- a/tests/test_kernel_execution.py +++ b/tests/test_kernel_execution.py @@ -12,7 +12,6 @@ ParticleSet, ScipyParticle, StatusCode, - Variable, ) ptype = {'scipy': ScipyParticle, 'jit': JITParticle} @@ -51,8 +50,7 @@ def MoveLon_Update_dlon(particle, fieldset, time): def SampleP(particle, fieldset, time): particle.p = fieldset.U[time, particle.depth, particle.lat, particle.lon] - class SampleParticle(ptype[mode]): - p = Variable('p', dtype=np.float32, initial=0.) + SampleParticle = ptype[mode].add_variable('p', dtype=np.float32, initial=0.) MoveLon = MoveLon_Update_dlon if kernel_type == 'update_dlon' else MoveLon_Update_Lon diff --git a/tests/test_kernel_language.py b/tests/test_kernel_language.py index 4168cdb71..75f96a3e7 100644 --- a/tests/test_kernel_language.py +++ b/tests/test_kernel_language.py @@ -55,8 +55,7 @@ def fieldset_fixture(xdim=20, ydim=20): ]) def test_expression_int(mode, name, expr, result, npart=10): """Test basic arithmetic expressions.""" - class TestParticle(ptype[mode]): - p = Variable('p', dtype=np.float32) + TestParticle = ptype[mode].add_variable('p', dtype=np.float32, initial=0) pset = ParticleSet(fieldset(), pclass=TestParticle, lon=np.linspace(0., 1., npart), lat=np.zeros(npart) + 0.5) @@ -74,8 +73,7 @@ class TestParticle(ptype[mode]): ]) def test_expression_float(mode, name, expr, result, npart=10): """Test basic arithmetic expressions.""" - class TestParticle(ptype[mode]): - p = Variable('p', dtype=np.float32) + TestParticle = ptype[mode].add_variable('p', dtype=np.float32, initial=0) pset = ParticleSet(fieldset(), pclass=TestParticle, lon=np.linspace(0., 1., npart), lat=np.zeros(npart) + 0.5) @@ -99,8 +97,7 @@ class TestParticle(ptype[mode]): ]) def test_expression_bool(mode, name, expr, result, npart=10): """Test basic arithmetic expressions.""" - class TestParticle(ptype[mode]): - p = Variable('p', dtype=np.float32) + TestParticle = ptype[mode].add_variable('p', dtype=np.float32, initial=0) pset = ParticleSet(fieldset(), pclass=TestParticle, lon=np.linspace(0., 1., npart), lat=np.zeros(npart) + 0.5) @@ -114,8 +111,7 @@ class TestParticle(ptype[mode]): @pytest.mark.parametrize('mode', ['scipy', 'jit']) def test_while_if_break(mode): """Test while, if and break commands.""" - class TestParticle(ptype[mode]): - p = Variable('p', dtype=np.float32, initial=0.) + TestParticle = ptype[mode].add_variable('p', dtype=np.float32, initial=0) pset = ParticleSet(fieldset(), pclass=TestParticle, lon=[0], lat=[0]) def kernel(particle, fieldset, time): @@ -132,9 +128,9 @@ def kernel(particle, fieldset, time): @pytest.mark.parametrize('mode', ['scipy', 'jit']) def test_nested_if(mode): """Test nested if commands.""" - class TestParticle(ptype[mode]): - p0 = Variable('p0', dtype=np.int32, initial=0) - p1 = Variable('p1', dtype=np.int32, initial=1) + TestParticle = ptype[mode].add_variables([ + Variable('p0', dtype=np.int32, initial=0), + Variable('p1', dtype=np.int32, initial=1)]) pset = ParticleSet(fieldset(), pclass=TestParticle, lon=0, lat=0) def kernel(particle, fieldset, time): @@ -150,8 +146,7 @@ def kernel(particle, fieldset, time): @pytest.mark.parametrize('mode', ['scipy', 'jit']) def test_pass(mode): """Test pass commands.""" - class TestParticle(ptype[mode]): - p = Variable('p', dtype=np.int32, initial=0) + TestParticle = ptype[mode].add_variable('p', dtype=np.float32, initial=0) pset = ParticleSet(fieldset(), pclass=TestParticle, lon=0, lat=0) def kernel(particle, fieldset, time): @@ -209,8 +204,7 @@ def kernel_abs(particle, fieldset, time): @pytest.mark.parametrize('mode', ['scipy', 'jit']) def test_if_withfield(fieldset, mode): """Test combination of if and Field sampling commands.""" - class TestParticle(ptype[mode]): - p = Variable('p', dtype=np.float32, initial=0.) + TestParticle = ptype[mode].add_variable('p', dtype=np.float32, initial=0) pset = ParticleSet(fieldset, pclass=TestParticle, lon=[0], lat=[0]) def kernel(particle, fieldset, time): @@ -242,8 +236,7 @@ def kernel(particle, fieldset, time): @pytest.mark.parametrize('mode', ['scipy']) def test_print(fieldset, mode, capfd): """Test print statements.""" - class TestParticle(ptype[mode]): - p = Variable('p', dtype=np.float32, initial=0.) + TestParticle = ptype[mode].add_variable('p', dtype=np.float32, initial=0) pset = ParticleSet(fieldset, pclass=TestParticle, lon=[0.5], lat=[0.5]) def kernel(particle, fieldset, time): @@ -301,8 +294,7 @@ def random_series(npart, rngfunc, rngargs, mode): ]) def test_random_float(mode, rngfunc, rngargs, npart=10): """Test basic random number generation.""" - class TestParticle(ptype[mode]): - p = Variable('p', dtype=np.float32 if rngfunc == 'randint' else np.float32) + TestParticle = ptype[mode].add_variable('p', dtype=np.float32, initial=0) pset = ParticleSet(fieldset(), pclass=TestParticle, lon=np.linspace(0., 1., npart), lat=np.zeros(npart) + 0.5) @@ -317,9 +309,7 @@ class TestParticle(ptype[mode]): @pytest.mark.parametrize('mode', ['scipy', 'jit']) @pytest.mark.parametrize('concat', [False, True]) def test_random_kernel_concat(fieldset, mode, concat): - class TestParticle(ptype[mode]): - p = Variable('p', dtype=np.float32) - + TestParticle = ptype[mode].add_variable('p', dtype=np.float32, initial=0) pset = ParticleSet(fieldset, pclass=TestParticle, lon=0, lat=0) def RandomKernel(particle, fieldset, time): @@ -374,8 +364,7 @@ def pykernel(particle, fieldset, time): @pytest.mark.parametrize('mode', ['scipy', 'jit']) def test_dt_modif_by_kernel(mode): - class TestParticle(ptype[mode]): - age = Variable('age', dtype=np.float32) + TestParticle = ptype[mode].add_variable('age', dtype=np.float32, initial=0) pset = ParticleSet(fieldset(), pclass=TestParticle, lon=[0.5], lat=[0]) def modif_dt(particle, fieldset, time): @@ -428,8 +417,7 @@ def generate_fieldset(xdim=2, ydim=2, zdim=2, tdim=1): data, dimensions = generate_fieldset() fieldset = FieldSet.from_data(data, dimensions) - class DensParticle(ptype[mode]): - density = Variable('density', dtype=np.float32) + DensParticle = ptype[mode].add_variable('density', dtype=np.float32) pset = ParticleSet(fieldset, pclass=DensParticle, lon=5, lat=5, depth=1000) @@ -446,22 +434,20 @@ def test_EOSseawaterproperties_kernels(mode): dimensions={'lat': 0, 'lon': 0, 'depth': 0}) fieldset.add_constant('refpressure', float(0)) - class PoTempParticle(ptype[mode]): - potemp = Variable('potemp', dtype=np.float32) - pressure = Variable('pressure', dtype=np.float32, initial=10000) + PoTempParticle = ptype[mode].add_variables([ + Variable('potemp', dtype=np.float32), + Variable('pressure', dtype=np.float32, initial=10000)]) pset = ParticleSet(fieldset, pclass=PoTempParticle, lon=5, lat=5, depth=1000) pset.execute(PtempFromTemp, runtime=1) assert np.allclose(pset[0].potemp, 36.89073) - class TempParticle(ptype[mode]): - temp = Variable('temp', dtype=np.float32) - pressure = Variable('pressure', dtype=np.float32, initial=10000) + TempParticle = ptype[mode].add_variables([ + Variable('temp', dtype=np.float32), + Variable('pressure', dtype=np.float32, initial=10000)]) pset = ParticleSet(fieldset, pclass=TempParticle, lon=5, lat=5, depth=1000) pset.execute(TempFromPtemp, runtime=1) assert np.allclose(pset[0].temp, 40) - class TPressureParticle(ptype[mode]): - pressure = Variable('pressure', dtype=np.float32) pset = ParticleSet(fieldset, pclass=TempParticle, lon=5, lat=30, depth=7321.45) pset.execute(PressureFromLatDepth, runtime=1) assert np.allclose(pset[0].pressure, 7500, atol=1e-2) @@ -491,8 +477,7 @@ def generate_fieldset(p, xdim=2, ydim=2, zdim=2, tdim=1): data, dimensions = generate_fieldset(pressure) fieldset = FieldSet.from_data(data, dimensions) - class DensParticle(ptype[mode]): - density = Variable('density', dtype=np.float32) + DensParticle = ptype[mode].add_variable('density', dtype=np.float32) pset = ParticleSet(fieldset, pclass=DensParticle, lon=5, lat=5, depth=1000) diff --git a/tests/test_particlefile.py b/tests/test_particlefile.py index 11356dfd3..f87ddb9c8 100644 --- a/tests/test_particlefile.py +++ b/tests/test_particlefile.py @@ -183,9 +183,9 @@ def Update_v(particle, fieldset, time): particle.v_once += 1. particle.age += particle.dt - class MyParticle(ptype[mode]): - v_once = Variable('v_once', dtype=np.float64, initial=0., to_write='once') - age = Variable('age', dtype=np.float32, initial=0.) + MyParticle = ptype[mode].add_variables([ + Variable('v_once', dtype=np.float64, initial=0., to_write='once'), + Variable('age', dtype=np.float32, initial=0.)]) lon = np.linspace(0, 1, npart) lat = np.linspace(1, 0, npart) time = np.arange(0, npart/10., 0.1, dtype=np.float64) @@ -209,9 +209,9 @@ def test_pset_repeated_release_delayed_adding_deleting(type, fieldset, mode, rep fieldset.maxvar = maxvar pset = None - class MyParticle(ptype[mode]): - sample_var = Variable('sample_var', initial=0.) - v_once = Variable('v_once', dtype=np.float64, initial=0., to_write='once') + MyParticle = ptype[mode].add_variables([ + Variable('sample_var', initial=0.), + Variable('v_once', dtype=np.float64, initial=0., to_write='once')]) if type == 'repeatdt': pset = ParticleSet(fieldset, lon=[0], lat=[0], pclass=MyParticle, repeatdt=repeatdt) @@ -287,10 +287,10 @@ def test_write_xiyi(fieldset, mode, tmpdir): fieldset.add_field(Field(name='P', data=np.zeros((2, 20)), lon=np.linspace(0, 1, 20), lat=[0, 2])) dt = 3600 - class XiYiParticle(ptype[mode]): - pxi0 = Variable('pxi0', dtype=np.int32, initial=0.) - pxi1 = Variable('pxi1', dtype=np.int32, initial=0.) - pyi = Variable('pyi', dtype=np.int32, initial=0.) + XiYiParticle = ptype[mode].add_variables([ + Variable('pxi0', dtype=np.int32, initial=0.), + Variable('pxi1', dtype=np.int32, initial=0.), + Variable('pyi', dtype=np.int32, initial=0.)]) def Get_XiYi(particle, fieldset, time): """Kernel to sample the grid indices of the particle. diff --git a/tests/test_particles.py b/tests/test_particles.py index fa8daffb6..c63d27cc0 100644 --- a/tests/test_particles.py +++ b/tests/test_particles.py @@ -60,8 +60,7 @@ def addOne(particle, fieldset, time): @pytest.mark.parametrize('type', ['np.int8', 'mp.float', 'np.int16']) def test_variable_unsupported_dtypes(fieldset, mode, type): """Test that checks errors thrown for unsupported dtypes in JIT mode.""" - class TestParticle(ptype[mode]): - p = Variable('p', dtype=type, initial=10.) + TestParticle = ptype[mode].add_variable('p', dtype=type, initial=10.) error_thrown = False try: ParticleSet(fieldset, pclass=TestParticle, lon=[0], lat=[0]) @@ -74,8 +73,7 @@ class TestParticle(ptype[mode]): def test_variable_special_names(fieldset, mode): """Test that checks errors thrown for special names.""" for vars in ['z', 'lon']: - class TestParticle(ptype[mode]): - tmp = Variable(vars, dtype=np.float32, initial=10.) + TestParticle = ptype[mode].add_variable(vars, dtype=np.float32, initial=10.) error_thrown = False try: ParticleSet(fieldset, pclass=TestParticle, lon=[0], lat=[0]) @@ -90,14 +88,11 @@ def test_variable_init_relative(fieldset, mode, coord_type, npart=10): """Test that checks relative initialisation of custom variables.""" lonlat_type = np.float64 if coord_type == 'double' else np.float32 - class TestParticle(ptype[mode]): - p_base = Variable('p_base', dtype=lonlat_type, initial=10.) - p_relative = Variable('p_relative', dtype=lonlat_type, - initial=attrgetter('p_base')) - p_lon = Variable('p_lon', dtype=lonlat_type, - initial=attrgetter('lon')) - p_lat = Variable('p_lat', dtype=lonlat_type, - initial=attrgetter('lat')) + TestParticle = ptype[mode].add_variables([ + Variable('p_base', dtype=lonlat_type, initial=10.), + Variable('p_relative', dtype=lonlat_type, initial=attrgetter('p_base')), + Variable('p_lon', dtype=lonlat_type, initial=attrgetter('lon')), + Variable('p_lat', dtype=lonlat_type, initial=attrgetter('lat'))]) lon = np.linspace(0, 1, npart, dtype=lonlat_type) lat = np.linspace(1, 0, npart, dtype=lonlat_type) diff --git a/tests/test_particlesets.py b/tests/test_particlesets.py index fcd7d34e3..9e1d64fdf 100644 --- a/tests/test_particlesets.py +++ b/tests/test_particlesets.py @@ -57,7 +57,7 @@ def test_pset_create_list_with_customvariable(fieldset, mode, npart=100): lon = np.linspace(0, 1, npart, dtype=np.float32) lat = np.linspace(1, 0, npart, dtype=np.float32) - MyParticle = ptype[mode].add_variable(Variable("v")) + MyParticle = ptype[mode].add_variable("v") v_vals = np.arange(npart) pset = ParticleSet.from_list(fieldset, lon=lon, lat=lat, v=v_vals, pclass=MyParticle) @@ -199,8 +199,7 @@ def IncrLon(particle, fieldset, time): @pytest.mark.parametrize('mode', ['scipy', 'jit']) def test_pset_repeatdt_custominit(fieldset, mode): - class MyParticle(ptype[mode]): - sample_var = Variable('sample_var') + MyParticle = ptype[mode].add_variable('sample_var') pset = ParticleSet(fieldset, lon=0, lat=0, pclass=MyParticle, repeatdt=1, sample_var=5) From c660fbe7a48ecd6edae60cf1df32506f519fd2b0 Mon Sep 17 00:00:00 2001 From: Erik van Sebille Date: Tue, 30 Jan 2024 11:37:13 +0100 Subject: [PATCH 10/12] Fixing small bugs in add_variable() tests --- docs/examples/tutorial_sampling.ipynb | 2 +- tests/test_advection.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/examples/tutorial_sampling.ipynb b/docs/examples/tutorial_sampling.ipynb index 609e0b33a..dce973ee2 100644 --- a/docs/examples/tutorial_sampling.ipynb +++ b/docs/examples/tutorial_sampling.ipynb @@ -341,7 +341,7 @@ "metadata": {}, "outputs": [], "source": [ - "SampleParticleOnce = JITParticle.add_variables(\n", + "SampleParticleOnce = JITParticle.add_variable(\n", " \"temperature\", initial=0, to_write=\"once\"\n", ")\n", "\n", diff --git a/tests/test_advection.py b/tests/test_advection.py index b1b3b4c1b..7e1133845 100644 --- a/tests/test_advection.py +++ b/tests/test_advection.py @@ -300,7 +300,7 @@ def test_stationary_eddy(fieldset_stationary, mode, method, rtol, diffField, npa dt = delta(minutes=3).total_seconds() endtime = delta(hours=6).total_seconds() - RK45Particles = ptype[mode]('next_dt', dtype=np.float32, initial=dt) + RK45Particles = ptype[mode].add_variable('next_dt', dtype=np.float32, initial=dt) pclass = RK45Particles if method == 'RK45' else ptype[mode] pset = ParticleSet(fieldset, pclass=pclass, lon=lon, lat=lat) @@ -393,7 +393,7 @@ def test_moving_eddy(fieldset_moving, mode, method, rtol, diffField, npart=1): dt = delta(minutes=3).total_seconds() endtime = delta(hours=6).total_seconds() - RK45Particles = ptype[mode]('next_dt', dtype=np.float32, initial=dt) + RK45Particles = ptype[mode].add_variable('next_dt', dtype=np.float32, initial=dt) pclass = RK45Particles if method == 'RK45' else ptype[mode] pset = ParticleSet(fieldset, pclass=pclass, lon=lon, lat=lat) @@ -458,7 +458,7 @@ def test_decaying_eddy(fieldset_decaying, mode, method, rtol, diffField, npart=1 dt = delta(minutes=3).total_seconds() endtime = delta(hours=6).total_seconds() - RK45Particles = ptype[mode]('next_dt', dtype=np.float32, initial=dt) + RK45Particles = ptype[mode].add_variable('next_dt', dtype=np.float32, initial=dt) pclass = RK45Particles if method == 'RK45' else ptype[mode] pset = ParticleSet(fieldset, pclass=pclass, lon=lon, lat=lat) From b41a4f2d9de82a05a2b06028e4db0c119ae3ab74 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 30 Jan 2024 10:37:36 +0000 Subject: [PATCH 11/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/examples/tutorial_sampling.ipynb | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/docs/examples/tutorial_sampling.ipynb b/docs/examples/tutorial_sampling.ipynb index dce973ee2..b2889f086 100644 --- a/docs/examples/tutorial_sampling.ipynb +++ b/docs/examples/tutorial_sampling.ipynb @@ -341,9 +341,7 @@ "metadata": {}, "outputs": [], "source": [ - "SampleParticleOnce = JITParticle.add_variable(\n", - " \"temperature\", initial=0, to_write=\"once\"\n", - ")\n", + "SampleParticleOnce = JITParticle.add_variable(\"temperature\", initial=0, to_write=\"once\")\n", "\n", "pset = ParticleSet(\n", " fieldset=fieldset, pclass=SampleParticleOnce, lon=lon, lat=lat, time=time\n", From 9dc55703ebe2d98e14903f28b6f01482e77699b7 Mon Sep 17 00:00:00 2001 From: Erik van Sebille Date: Tue, 30 Jan 2024 12:11:55 +0100 Subject: [PATCH 12/12] Cleaning particle.py by removing _Particle baseclass --- parcels/particle.py | 69 +++++++++++++++++---------------------------- 1 file changed, 26 insertions(+), 43 deletions(-) diff --git a/parcels/particle.py b/parcels/particle.py index 52841a89a..769dda2ab 100644 --- a/parcels/particle.py +++ b/parcels/particle.py @@ -127,43 +127,7 @@ def supported_dtypes(self): return [np.int32, np.uint32, np.int64, np.uint64, np.float32, np.double, np.float64, c_void_p] -class _Particle: - """Private base class for all particle types.""" - - lastID = 0 # class-level variable keeping track of last Particle ID used - - def __init__(self): - ptype = self.getPType() - # Explicit initialisation of all particle variables - for v in ptype.variables: - if isinstance(v.initial, attrgetter): - initial = v.initial(self) - else: - initial = v.initial - # Enforce type of initial value - if v.dtype != c_void_p: - setattr(self, v.name, v.dtype(initial)) - - # Placeholder for explicit error handling - self.exception = None - - def __del__(self): - pass # superclass is 'object', and object itself has no destructor, hence 'pass' - - @classmethod - def getPType(cls): - return ParticleType(cls) - - @classmethod - def getInitialValue(cls, ptype, name): - return next((v.initial for v in ptype.variables if v.name is name), None) - - @classmethod - def setLastID(cls, offset): - _Particle.lastID = offset - - -class ScipyParticle(_Particle): +class ScipyParticle: """Class encapsulating the basic attributes of a particle, to be executed in SciPy mode. Parameters @@ -198,6 +162,8 @@ class ScipyParticle(_Particle): dt = Variable('dt', dtype=np.float64, to_write=False) state = Variable('state', dtype=np.int32, initial=StatusCode.Evaluate, to_write=False) + lastID = 0 # class-level variable keeping track of last Particle ID used + def __init__(self, lon, lat, pid, fieldset=None, ngrids=None, depth=0., time=0., cptr=None): # Enforce default values through Variable descriptor @@ -210,14 +176,23 @@ def __init__(self, lon, lat, pid, fieldset=None, ngrids=None, depth=0., time=0., type(self).time.initial = time type(self).time_nextloop.initial = time type(self).id.initial = pid - _Particle.lastID = max(_Particle.lastID, pid) + type(self).lastID = max(type(self).lastID, pid) type(self).obs_written.initial = 0 type(self).dt.initial = None - super().__init__() + ptype = self.getPType() + # Explicit initialisation of all particle variables + for v in ptype.variables: + if isinstance(v.initial, attrgetter): + initial = v.initial(self) + else: + initial = v.initial + # Enforce type of initial value + if v.dtype != c_void_p: + setattr(self, v.name, v.dtype(initial)) def __del__(self): - super().__del__() + pass # superclass is 'object', and object itself has no destructor, hence 'pass' def __repr__(self): time_string = 'not_yet_set' if self.time is None or np.isnan(self.time) else f"{self.time:f}" @@ -242,11 +217,11 @@ def add_variable(cls, var, *args, **kwargs): if isinstance(var, list): return cls.add_variables(var) if not isinstance(var, Variable): - if len(args) > 0: + if len(args) > 0 and 'dtype' not in kwargs: kwargs['dtype'] = args[0] - if len(args) > 1: + if len(args) > 1 and 'initial' not in kwargs: kwargs['initial'] = args[1] - if len(args) > 2: + if len(args) > 2 and 'to_write' not in kwargs: kwargs['to_write'] = args[2] dtype = kwargs.pop('dtype', np.float32) initial = kwargs.pop('initial', 0) @@ -273,6 +248,10 @@ def add_variables(cls, variables): NewParticle = NewParticle.add_variable(var) return NewParticle + @classmethod + def getPType(cls): + return ParticleType(cls) + @classmethod def set_lonlatdepth_dtype(cls, dtype): cls.lon.dtype = dtype @@ -282,6 +261,10 @@ def set_lonlatdepth_dtype(cls, dtype): cls.lat_nextloop.dtype = dtype cls.depth_nextloop.dtype = dtype + @classmethod + def setLastID(cls, offset): + ScipyParticle.lastID = offset + ScipyInteractionParticle = ScipyParticle.add_variables([ Variable("vert_dist", dtype=np.float32),