Skip to content

Commit

Permalink
Ruff tardis/model (tardis-sn#2828)
Browse files Browse the repository at this point in the history
* ruff model auto safe fixes

* ruff model auto unsafe fixes

* black tardis/model/base.py
  • Loading branch information
atharva-2001 authored and CePowers committed Oct 1, 2024
1 parent 2c1a2ac commit 3b3a69b
Show file tree
Hide file tree
Showing 8 changed files with 25 additions and 33 deletions.
6 changes: 3 additions & 3 deletions tardis/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""
Holding information about the TARDIS model.
Holding information about the TARDIS model.
This includes information about the shell structure,
density, abundance, temperatures and dilution
This includes information about the shell structure,
density, abundance, temperatures and dilution
factor of the model used in the simulation.
"""

Expand Down
18 changes: 8 additions & 10 deletions tardis/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,22 +350,20 @@ def from_csvy(cls, config, atom_data=None, legacy_mode_enabled=False):
)

if hasattr(csvy_model_data, "columns"):
abund_names = set(
[
name
for name in csvy_model_data.columns
if is_valid_nuclide_or_elem(name)
]
)
abund_names = {
name
for name in csvy_model_data.columns
if is_valid_nuclide_or_elem(name)
}
unsupported_columns = (
set(csvy_model_data.columns)
- abund_names
- CSVY_SUPPORTED_COLUMNS
)

field_names = set(
[field["name"] for field in csvy_model_config.datatype.fields]
)
field_names = {
field["name"] for field in csvy_model_config.datatype.fields
}
assert (
set(csvy_model_data.columns) - field_names == set()
), "CSVY columns exist without field descriptions"
Expand Down
9 changes: 5 additions & 4 deletions tardis/model/geometry/radial1d.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from numba import float64
from numba.experimental import jitclass
import warnings

import numpy as np
from astropy import units as u
import warnings
from numba import float64
from numba.experimental import jitclass


class HomologousRadial1DGeometry:
Expand Down Expand Up @@ -193,7 +194,7 @@ def to_numba(self):


@jitclass(numba_geometry_spec)
class NumbaRadial1DGeometry(object):
class NumbaRadial1DGeometry:
def __init__(self, r_inner, r_outer, v_inner, v_outer):
"""
Radial 1D Geometry for the Numba mode
Expand Down
5 changes: 2 additions & 3 deletions tardis/model/geometry/tests/test_radial1d.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from astropy import units as u
import numpy as np
import numpy.testing as npt
import pytest
from astropy import units as u

from tardis.model.geometry.radial1d import HomologousRadial1DGeometry

import pytest


@pytest.fixture(scope="function")
def homologous_radial1d_geometry():
Expand Down
8 changes: 4 additions & 4 deletions tardis/model/matter/decay.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,23 +17,23 @@ def __init__(self, *args, **kwargs):
kwargs.pop("time_0")
else:
time_0 = 0 * u.d
super(IsotopicMassFraction, self).__init__(*args, **kwargs)
super().__init__(*args, **kwargs)
self.time_0 = time_0

@property
def _constructor(self):
return IsotopicMassFraction

def _update_inventory(self):
self.comp_dicts = [dict() for i in range(len(self.columns))]
self.comp_dicts = [{} for i in range(len(self.columns))]
for (atomic_number, mass_number), mass_fractions in self.iterrows():
nuclear_symbol = f"{Z_to_elem(atomic_number)}{mass_number}"
for i in range(len(self.columns)):
self.comp_dicts[i][nuclear_symbol] = mass_fractions[i]

@classmethod
def from_inventories(cls, inventories):
multi_index_tuples = set([])
multi_index_tuples = set()
for inventory in inventories:
multi_index_tuples.update(
[cls.id_to_tuple(key) for key in inventory.contents.keys()]
Expand Down Expand Up @@ -68,7 +68,7 @@ def to_inventories(self, shell_masses=None):
list
list of radioactivedecay Inventories
"""
comp_dicts = [dict() for i in range(len(self.columns))]
comp_dicts = [{} for i in range(len(self.columns))]
for (atomic_number, mass_number), abundances in self.iterrows():
nuclear_symbol = f"{Z_to_elem(atomic_number)}{mass_number}"
for i in range(len(self.columns)):
Expand Down
4 changes: 1 addition & 3 deletions tardis/model/tests/test_base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
import os
from pathlib import Path
import pytest
import pandas as pd
import pytest
from astropy import units as u
from numpy.testing import assert_almost_equal, assert_array_almost_equal

Expand Down
7 changes: 2 additions & 5 deletions tardis/model/tests/test_csvy_model.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
from pathlib import Path
import numpy as np
import pandas as pd
import numpy.testing as npt
import pandas as pd
import pytest

from astropy import units as u
from tardis.io.configuration.config_reader import Configuration
from tardis.io.atom_data.base import AtomData
from tardis.model import SimulationState
import pytest


@pytest.fixture(
Expand Down
1 change: 0 additions & 1 deletion tardis/model/tests/test_density.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import os
import pandas as pd
import pytest
from numpy.testing import assert_almost_equal
Expand Down

0 comments on commit 3b3a69b

Please sign in to comment.