Skip to content

Commit

Permalink
[#159] updating unit tests for add_bins()
Browse files Browse the repository at this point in the history
  • Loading branch information
dvezinet committed Nov 15, 2024
1 parent 896f29a commit 6151bd1
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 65 deletions.
61 changes: 6 additions & 55 deletions datastock/tests/test_01_DataStock.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,63 +350,14 @@ def test08_domain_ref(self):
lk = list(domain.keys())
assert all([isinstance(dout[k0]['ind'], np.ndarray) for k0 in lk])

def test09_binning(self):

bins = np.linspace(1, 5, 8)
lk = [
('y', 'nx', bins, 0, False, False, 'y_bin0'),
('y', 'nx', bins, 0, True, False, 'y_bin1'),
('y', 'nx', 'x', 0, False, True, 'y_bin2'),
('y', 'nx', 'x', 0, True, True, 'y_bin3'),
('prof0', 'x', 'nt0', 1, False, True, 'p0_bin0'),
('prof0', 'x', 'nt0', 1, True, True, 'p0_bin1'),
('prof0-bis', 'prof0', 'x', [0, 1], False, True, 'p1_bin0'),
]

for ii, (k0, kr, kb, ax, integ, store, kbin) in enumerate(lk):
dout = self.st.binning(
data=k0,
bin_data0=kr,
bins0=kb,
axis=ax,
integrate=integ,
store=store,
store_keys=kbin,
safety_ratio=0.95,
returnas=True,
)

if np.isscalar(ax):
ax = [ax]

if isinstance(kb, str):
if kb in self.st.ddata:
nb = self.st.ddata[kb]['data'].size
else:
nb = self.st.dref[kb]['size']
else:
nb = bins.size

k0 = list(dout.keys())[0]
shape = [
ss for ii, ss in enumerate(self.st.ddata[k0]['data'].shape)
if ii not in ax
]

shape.insert(ax[0], nb)
if dout[k0]['data'].shape != tuple(shape):
shstr = dout[k0]['data'].shape
msg = (
"Mismatching shapes for case {ii}!\n"
f"\t- dout['{k0}']['data'].shape = {shstr}\n"
f"\t- expected: {tuple(shape)}"
)
raise Exception(msg)

def test10_add_bins(self):
def test09_add_bins(self):
_input.add_bins(self.coll)

def test10_interpolate(self):
def test10_binning(self):
# _input.binning(self.coll)
pass

def test11_interpolate(self):

lk = ['y', 'y', 'prof0', 'prof0', 'prof0', '3d']
lref = [None, 'nx', 't0', ['nt0', 'nx'], ['t0', 'x'], ['t0', 'x']]
Expand Down
78 changes: 68 additions & 10 deletions datastock/tests/test_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,25 +23,83 @@ def add_bins(coll):
# -------------------------

# linear uniform 1d
coll.add_bins('bin0', edges=np.linspace(0, 1, 10), units='m')
coll.add_bins('b1d_lin', edges=np.linspace(0, 1, 10), units='m')

# log uniform 1d
coll.add_bins(edges=np.logspace(0, 1, 10), units='eV')
coll.add_bins('b1d_log', edges=np.logspace(0, 1, 10), units='eV')

# non-uniform 1d
coll.add_bins(edges=np.r_[1, 2, 5, 10, 12, 20], units='s')
coll.add_bins('b2d_rand', edges=np.r_[1, 2, 5, 10, 12, 20], units='s')

# linear uniform 2d
coll.add_bins('bin0', edges=np.linspace(0, 1, 10), units='m')

# log uniform 2d
coll.add_bins(edges=np.logspace(0, 1, 10), units='eV')

# non-uniform 2d
coll.add_bins(edges=np.r_[1, 2, 5, 10, 12, 20], units='s')
coll.add_bins(
'b2d_lin',
edges=(np.linspace(0, 1, 10), np.linspace(0, 3, 20)),
units='m',
)

# log uniform mix 2d
coll.add_bins(
'b2d_mix',
edges=(np.logspace(0, 1, 10), np.pi*np.r_[0, 0.5, 1, 1.2, 1.5, 2]),
units=('eV', 'rad'),
)

# -------------------------
# define bins pre-existing
# -------------------------

return


def binning(coll):
bins = np.linspace(1, 5, 8)
lk = [
('y', 'nx', bins, 0, False, False, 'y_bin0'),
('y', 'nx', bins, 0, True, False, 'y_bin1'),
('y', 'nx', 'x', 0, False, True, 'y_bin2'),
('y', 'nx', 'x', 0, True, True, 'y_bin3'),
('prof0', 'x', 'nt0', 1, False, True, 'p0_bin0'),
('prof0', 'x', 'nt0', 1, True, True, 'p0_bin1'),
('prof0-bis', 'prof0', 'x', [0, 1], False, True, 'p1_bin0'),
]

for ii, (k0, kr, kb, ax, integ, store, kbin) in enumerate(lk):
dout = coll.binning(
data=k0,
bin_data0=kr,
bins0=kb,
axis=ax,
integrate=integ,
store=store,
store_keys=kbin,
safety_ratio=0.95,
returnas=True,
)

if np.isscalar(ax):
ax = [ax]

if isinstance(kb, str):
if kb in coll.ddata:
nb = coll.ddata[kb]['data'].size
else:
nb = coll.dref[kb]['size']
else:
nb = bins.size

k0 = list(dout.keys())[0]
shape = [
ss for ii, ss in enumerate(coll.ddata[k0]['data'].shape)
if ii not in ax
]

shape.insert(ax[0], nb)
if dout[k0]['data'].shape != tuple(shape):
shstr = dout[k0]['data'].shape
msg = (
"Mismatching shapes for case {ii}!\n"
f"\t- dout['{k0}']['data'].shape = {shstr}\n"
f"\t- expected: {tuple(shape)}"
)
raise Exception(msg)

0 comments on commit 6151bd1

Please sign in to comment.