Skip to content
181 changes: 168 additions & 13 deletions exptool/io/psp_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,33 @@
except ImportError:
raise ImportError("You will need to 'pip install pyyaml' to use this reader.")

def _to_python(obj):
"""
Recursively convert numpy types to native Python types.

This helper function ensures that objects containing numpy types (such as numpy scalars or arrays)
are converted to their native Python equivalents. This is particularly important for YAML serialization,
which may not handle numpy types correctly.

Parameters
----------
obj : any
Any Python object, potentially containing numpy types (e.g., numpy scalars, arrays, or nested structures).

Returns
-------
out : any
The input object with all numpy types converted to native Python types.
"""
if isinstance(obj, dict):
return {k: _to_python(v) for k, v in obj.items()}
elif isinstance(obj, list):
return [_to_python(v) for v in obj]
elif hasattr(obj, 'item'): # catches numpy scalars
return obj.item()
else:
return obj


class Input:
"""Input class to adaptively handle OUT. format specifically
Expand Down Expand Up @@ -79,27 +106,76 @@ def __init__(self, filename,comp=None,verbose=0):
# do an initial read of the header
self.primary_header = dict()

# initialise dictionaries
self.comp_map = dict()
# initialise dictionary for headers
self.header = dict()

self._read_primary_header()

self.comp = comp
_comps = list(self.header.keys())


# if a component is defined, retrieve data
if comp != None:
if comp not in _comps:
raise IOError('The specified component does not exist.')

# or check if we are reading all components
if comp == 'all':
self.data = dict()
for c in _comps:
self.data[c] = self._read_component_data(self.filename,
c,
self.header[c]['nbodies'],
int(self.header[c]['data_start']))

else:
self.data = self._read_component_data(self.filename,
self.header[self.comp]['nbodies'],
int(self.header[self.comp]['data_start']))
if comp not in _comps:
raise IOError(f'The specified component, {comp}, does not exist.')

else:
self.data = self._read_component_data(self.filename,
self.comp,
self.header[self.comp]['nbodies'],
int(self.header[self.comp]['data_start']))

# wrapup
self.f.close()

def write(self, filename):
    """
    Serialise the currently loaded snapshot back to a PSP/OUT. file.

    Parameters
    ----------
    filename : str
        The output filename to which the data will be written.

    Raises
    ------
    NotImplementedError
        If the instance was not read with ``comp='all'``: writing a
        single component is not supported.

    Example
    -------
    >>> inp = Input("input.OUT", comp="all")
    >>> inp.write("output.OUT")
    """
    # only a full round-trip (every component loaded) can be written back
    if self.comp != 'all':
        raise NotImplementedError("Writing single components is not implemented yet. Use comp='all'.")

    with open(filename, 'wb') as fout:
        # primary header first, then each component's header/data pair
        # in the order the components were read
        self._write_primary_header(fout)

        for name in self.header:
            self._write_component_header(fout, self.header[name])
            self._write_component_data(fout, name, self.data[name])

def _read_primary_header(self):
"""read the primary header from an OUT. file"""

Expand All @@ -117,6 +193,25 @@ def _read_primary_header(self):
next_comp = self._read_out_component_header()
data_start = next_comp

def _write_primary_header(self, f):
    """
    Write the file-level (primary) header to an open binary stream.

    Parameters
    ----------
    f : file object
        Binary stream positioned at the start of the output file.
    """
    # the simulation time is always stored as little-endian float64
    np.array([self.time], dtype='<f8').tofile(f)

    # then the total particle count across all components and the
    # number of components, both as uint32
    nbodies_total = sum(self.header[name]['nbodies'] for name in self.header)
    np.array([nbodies_total, len(self.header)], dtype=np.uint32).tofile(f)

    # component headers must follow immediately after this point.
    # TODO: a magic number at byte 16 is required for full PSP file
    # compatibility but is not yet implemented; uncomment and verify
    # the following when adding support for it.
    #f.seek(16)
    #magic = 2915019716 if self._float_len == 4 else 0
    #np.array([magic], dtype=np.uint32).tofile(f)
    # the magic number is doubled up for consistency
    #np.array([magic], dtype=np.uint32).tofile(f)
def _summarise_primary_header(self):
"""a short summary of what is in the file"""

Expand All @@ -133,6 +228,7 @@ def _read_out_component_header(self):

#_ = f.tell() # byte position of this component


if self._float_len == 4:
_1,_2, nbodies, nint_attr, nfloat_attr, infostringlen = np.fromfile(self.f, dtype=np.uint32, count=6)
else:
Expand Down Expand Up @@ -161,6 +257,8 @@ def _read_out_component_header(self):
head_dict['nbodies'] = nbodies
head_dict['data_start'] = comp_data_pos
head_dict['data_end'] = comp_data_end
head_dict['info_len'] = infostringlen
head_dict['info_str'] = head_normal

self.header[head_dict['name']] = head_dict

Expand All @@ -169,9 +267,39 @@ def _read_out_component_header(self):
self.indexing = head_dict['parameters']['indexing']
except:
self.indexing = head_dict['indexing']=='true'
head_dict['parameters']['indexing'] = self.indexing

return comp_data_end

def _write_component_header(self, f, compdict):
    """
    Write one component header to an open binary stream.

    Parameters
    ----------
    f : file object
        Binary stream positioned where the component header belongs.
    compdict : dict
        One entry of ``self.header`` (i.e. ``self.header[name]``).
    """
    # strip numpy types so yaml.safe_dump can serialise the dictionary
    compdict_clean = _to_python(compdict)
    info_bytes = yaml.safe_dump(compdict_clean).encode()

    counts = (_to_python(compdict['nbodies']),
              _to_python(compdict['nint_attr']),
              _to_python(compdict['nfloat_attr']),
              len(info_bytes))

    if self._float_len == 4:
        # float files lead the six uint32 words with the magic number
        # and a zero pad word before the counts
        words = (2915019716, 0) + counts
    else:
        # double files carry only the four count words in the header
        words = counts

    np.array(words, dtype=np.uint32).tofile(f)
    f.write(info_bytes)

def _check_magic_number(self):
"""check the magic number to see if a file is float or double"""
Expand All @@ -189,26 +317,26 @@ def _check_magic_number(self):



def _read_component_data(self,filename,nbodies,offset):
def _read_component_data(self,filename,comp,nbodies,offset):
"""read in all data for component"""

dtype_str = []
colnames = []
if self.header[self.comp]['parameters']['indexing']:
if self.header[comp]['parameters']['indexing']:
# if indexing is on, the 0th column is Long
dtype_str = dtype_str + ['l']
colnames = colnames + ['id']

dtype_str = dtype_str + [self._float_str] * 8
colnames = colnames + ['m', 'x', 'y', 'z', 'vx', 'vy', 'vz', 'potE']

dtype_str = dtype_str + ['i'] * self.header[self.comp]['nint_attr']
dtype_str = dtype_str + ['i'] * self.header[comp]['nint_attr']
colnames = colnames + ['i_attr{}'.format(i)
for i in range(self.header[self.comp]['nint_attr'])]
for i in range(self.header[comp]['nint_attr'])]

dtype_str = dtype_str + [self._float_str] * self.header[self.comp]['nfloat_attr']
dtype_str = dtype_str + [self._float_str] * self.header[comp]['nfloat_attr']
colnames = colnames + ['f_attr{}'.format(i)
for i in range(self.header[self.comp]['nfloat_attr'])]
for i in range(self.header[comp]['nfloat_attr'])]

dtype = np.dtype(','.join(dtype_str))

Expand All @@ -228,3 +356,30 @@ def _read_component_data(self,filename,nbodies,offset):
del out # close the memmap instance

return tbl

def _write_component_data(self, f, comp, data):
    """
    Write the particle table for one component to an open binary stream.

    Parameters
    ----------
    f : file object
        Binary stream positioned at the component's data section.
    comp : str
        Component name, used to look up the header entry.
    data : dict-like
        Mapping of column name to array, matching the layout produced
        by ``_read_component_data``.
    """
    compdict = self.header[comp]

    # assemble the on-disk record layout: optional long id column, the
    # eight phase-space columns, then integer and float attributes
    fields = []
    if compdict['parameters']['indexing']:
        fields.append(('id', 'l'))
    for col in ('m', 'x', 'y', 'z', 'vx', 'vy', 'vz', 'potE'):
        fields.append((col, self._float_str))
    for i in range(compdict['nint_attr']):
        fields.append((f'i_attr{i}', 'i'))
    for i in range(compdict['nfloat_attr']):
        fields.append((f'f_attr{i}', self._float_str))

    record = np.dtype(','.join(code for _, code in fields))

    # fill a structured array field-by-field (fields are auto-named
    # f0, f1, ... by numpy) and dump it in one shot
    outarr = np.zeros(compdict['nbodies'], dtype=record)
    for i, (col, _) in enumerate(fields):
        outarr[f'f{i}'] = data[col]

    outarr.tofile(f)