Skip to content

Commit

Permalink
restore column units on cast (#242)
Browse files Browse the repository at this point in the history
  • Loading branch information
braingram authored Dec 11, 2023
1 parent bbc9016 commit db64580
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 1 deletion.
3 changes: 3 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ Bug Fixes
- Fix ``rebuild_fits_rec_dtype`` handling of unsigned integer columns
with shapes [#213]

- Fix unit roundtripping when writing to a datamodel with a table
to a FITS file [#242]

Changes to API
--------------

Expand Down
17 changes: 16 additions & 1 deletion src/stdatamodels/jwst/datamodels/tests/test_fits.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pytest

from stdatamodels.jwst import datamodels
from stdatamodels.jwst.datamodels import ImageModel, JwstDataModel, RampModel
from stdatamodels.jwst.datamodels import ImageModel, JwstDataModel, RampModel, SpecModel


@pytest.fixture
Expand Down Expand Up @@ -91,3 +91,18 @@ def test_resave_duplication_bug(tmp_path):

with fits.open(fn1) as ff1, fits.open(fn2) as ff2:
assert ff1['ASDF'].size == ff2['ASDF'].size


def test_units_roundtrip(tmp_path):
m = SpecModel()
# this next line is required for stdatamodels to cast
# spec_table to a FITS_rec (similar to having data assigned
# to the attribute)
m.spec_table = m.spec_table
m.spec_table.columns['WAVELENGTH'].unit = 'nm'

fn = tmp_path / "test1.fits"
m.save(fn)

m = datamodels.open(fn)
assert m.spec_table.columns['WAVELENGTH'].unit == 'nm'
10 changes: 10 additions & 0 deletions src/stdatamodels/properties.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,20 @@ def _cast(val, schema):
t['shape'] = shape

dtype = ndarray.asdf_datatype_to_numpy_dtype(schema['datatype'])

# save columns in case this is cast back to a fitsrec
if hasattr(val, 'columns'):
cols = val.columns
else:
cols = None
val = util.gentle_asarray(val, dtype, allow_extra_columns=allow_extra_columns)

if dtype.fields is not None:
val = _as_fitsrec(val)
if cols is not None:
for col in cols:
if col.name in val.names and col.unit is not None:
val.columns[col.name].unit = col.unit

if 'ndim' in schema and len(val.shape) != schema['ndim']:
raise ValueError(
Expand Down

0 comments on commit db64580

Please sign in to comment.