Skip to content

Commit

Permalink
Merge pull request #554 from ggmarshall/main
Browse files Browse the repository at this point in the history
Bugfix for pargen `load_data` to evaluate all fields in the hit dict; specify the dtype when initializing LH5 objects in evt; and support subtables in skm.
  • Loading branch information
gipert authored Feb 2, 2024
2 parents 967dd4c + 0b569e9 commit 3399d20
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 33 deletions.
11 changes: 6 additions & 5 deletions src/pygama/evt/aggregators.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def evaluate_to_first_or_last(
(t0 > outt[evt_ids_ch]) & (limarr), t0, outt[evt_ids_ch]
)

return Array(nda=out)
return Array(nda=out, dtype=type(defv))


def evaluate_to_scalar(
Expand Down Expand Up @@ -271,7 +271,7 @@ def evaluate_to_scalar(
res = res.astype(bool)
out[evt_ids_ch] = out[evt_ids_ch] & res & limarr

return Array(nda=out)
return Array(nda=out, dtype=type(defv))


def evaluate_at_channel(
Expand Down Expand Up @@ -356,7 +356,7 @@ def evaluate_at_channel(

out[evt_ids_ch] = np.where(ch == ch_comp.nda[idx_ch], res, out[evt_ids_ch])

return Array(nda=out)
return Array(nda=out, dtype=type(defv))


def evaluate_at_channel_vov(
Expand Down Expand Up @@ -451,7 +451,7 @@ def evaluate_at_channel_vov(
if ch == chns[0]:
type_name = res.dtype

return VectorOfVectors(ak.values_astype(out, type_name))
return VectorOfVectors(ak.values_astype(out, type_name), dtype=type_name)


def evaluate_to_aoesa(
Expand Down Expand Up @@ -684,5 +684,6 @@ def evaluate_to_vector(
)

return VectorOfVectors(
ak.values_astype(ak.drop_none(ak.nan_to_none(ak.Array(out))), type(defv))
ak.values_astype(ak.drop_none(ak.nan_to_none(ak.Array(out))), type(defv)),
dtype=type(defv),
)
13 changes: 8 additions & 5 deletions src/pygama/math/histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import rcParams
import logging

import pygama.math.utils as pgu

log = logging.getLogger(__name__)


def get_hist(data, bins=None, range=None, dx=None, wts=None):
"""return hist, bins, var after binning data
Expand Down Expand Up @@ -361,7 +364,7 @@ def get_fwfm(fraction, hist, bins, var=None, mx=None, dmx=0, bl=0, dbl=0, method
# x_lo
i_0 = bin_lo - int(np.floor(n_slope/2))
if i_0 < 0:
print(f"get_fwfm: fit slopes failed")
log.debug(f"get_fwfm: fit slopes failed")
return 0, 0
i_n = i_0 + n_slope
wts = None if var is None else 1/np.sqrt(var[i_0:i_n]) #fails for any var = 0
Expand All @@ -370,7 +373,7 @@ def get_fwfm(fraction, hist, bins, var=None, mx=None, dmx=0, bl=0, dbl=0, method
try:
(m, b), cov = np.polyfit(bin_centers[i_0:i_n], hist[i_0:i_n], 1, w=wts, cov='unscaled')
except np.linalg.LinAlgError:
print(f"get_fwfm: LinAlgError")
log.debug(f"get_fwfm: LinAlgError")
return 0, 0
x_lo = (val_f-b)/m
#uncertainty
Expand All @@ -380,7 +383,7 @@ def get_fwfm(fraction, hist, bins, var=None, mx=None, dmx=0, bl=0, dbl=0, method
# x_hi
i_0 = bin_hi - int(np.floor(n_slope/2)) + 1
if i_0 == len(hist):
print(f"get_fwfm: fit slopes failed")
log.debug(f"get_fwfm: fit slopes failed")
return 0, 0

i_n = i_0 + n_slope
Expand All @@ -389,11 +392,11 @@ def get_fwfm(fraction, hist, bins, var=None, mx=None, dmx=0, bl=0, dbl=0, method
try:
(m, b), cov = np.polyfit(bin_centers[i_0:i_n], hist[i_0:i_n], 1, w=wts, cov='unscaled')
except np.linalg.LinAlgError:
print(f"get_fwfm: LinAlgError")
log.debug(f"get_fwfm: LinAlgError")
return 0, 0
x_hi = (val_f-b)/m
if x_hi < x_lo:
print(f"get_fwfm: fit slopes produced negative fwfm")
log.debug(f"get_fwfm: fit slopes produced negative fwfm")
return 0, 0

#uncertainty
Expand Down
3 changes: 3 additions & 0 deletions src/pygama/math/peak_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@
from iminuit import Minuit, cost
from scipy.optimize import brentq, minimize_scalar
from scipy.stats import crystalball
import logging

import pygama.math.histogram as pgh

log = logging.getLogger(__name__)

limit = np.log(sys.float_info.max)/10
kwd = {"parallel": False, "fastmath": True}

Expand Down
23 changes: 11 additions & 12 deletions src/pygama/pargen/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,18 +69,19 @@ def load_data(
masks = np.array([], dtype=bool)
for tstamp, tfiles in files.items():
table = sto.read(lh5_path, tfiles)[0]

file_df = pd.DataFrame(columns=params)
if tstamp in cal_dict:
cal_dict_ts = cal_dict[tstamp]
else:
cal_dict_ts = cal_dict

for outname, info in cal_dict_ts.items():
outcol = table.eval(info["expression"], info.get("parameters", None))
table.add_column(outname, outcol)

for param in params:
if param in cal_dict_ts:
expression = cal_dict_ts[param]["expression"]
parameters = cal_dict_ts[param].get("parameters", None)
file_df[param] = table.eval(expression, parameters)
else:
file_df[param] = table[param]
file_df[param] = table[param]
file_df["run_timestamp"] = np.full(len(file_df), tstamp, dtype=object)
params.append("run_timestamp")
if threshold is not None:
Expand All @@ -101,13 +102,11 @@ def load_data(

table = sto.read(lh5_path, files)[0]
df = pd.DataFrame(columns=params)
for outname, info in cal_dict.items():
outcol = table.eval(info["expression"], info.get("parameters", None))
table.add_column(outname, outcol)
for param in params:
if param in cal_dict:
expression = cal_dict[param]["expression"]
parameters = cal_dict[param].get("parameters", None)
df[param] = table.eval(expression, parameters)
else:
df[param] = table[param]
df[param] = table[param]
if threshold is not None:
masks = df[cal_energy_param] > threshold
df.drop(np.where(~masks)[0], inplace=True)
Expand Down
26 changes: 15 additions & 11 deletions src/pygama/skm/build_skm.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,23 +139,24 @@ def build_skm(
):
miss_val = eval(miss_val)

fw_fld = tbl_cfg["operations"][op]["forward_field"].split(".")
fw_fld = tbl_cfg["operations"][op]["forward_field"]

# load object if from evt tier
if fw_fld[0] == evt_group:
obj = store.read(f"/{fw_fld[0]}/{fw_fld[1]}", f_dict[fw_fld[0]])[
0
].view_as("ak")
if evt_group in fw_fld.replace(".", "/"):
obj = store.read(
f"/{fw_fld.replace('.','/')}", f_dict[fw_fld.split(".", 1)[0]]
)[0].view_as("ak")

# else collect data from lower tier via tcm_idx
else:
if "tcm_idx" not in tbl_cfg["operations"][op].keys():
raise ValueError(
f"{op} is an sub evt level operation. tcm_idx field must be specified"
)
tcm_idx_fld = tbl_cfg["operations"][op]["tcm_idx"].split(".")
tcm_idx_fld = tbl_cfg["operations"][op]["tcm_idx"]
tcm_idx = store.read(
f"/{tcm_idx_fld[0]}/{tcm_idx_fld[1]}", f_dict[tcm_idx_fld[0]]
f"/{tcm_idx_fld.replace('.','/')}",
f_dict[tcm_idx_fld.split(".")[0]],
)[0].view_as("ak")[:, :multi]

obj = ak.Array([[] for x in range(len(tcm_idx))])
Expand Down Expand Up @@ -184,14 +185,17 @@ def build_skm(
fl_idx = ak.to_numpy(ak.flatten(ch_idx), allow_missing=False)

if (
f"{utils.get_table_name_by_pattern(tcm_id_table_pattern,ch)}/{fw_fld[0]}/{fw_fld[1]}"
not in lh5.ls(f_dict[fw_fld[0]], f"ch{ch}/{fw_fld[0]}/")
f"{utils.get_table_name_by_pattern(tcm_id_table_pattern,ch)}/{fw_fld.replace('.','/')}"
not in lh5.ls(
f_dict[[key for key in f_dict if key in fw_fld][0]],
f"ch{ch}/{fw_fld.rsplit('.',1)[0]}/",
)
):
och = Array(nda=np.full(len(fl_idx), miss_val))
else:
och, _ = store.read(
f"{utils.get_table_name_by_pattern(tcm_id_table_pattern,ch)}/{fw_fld[0]}/{fw_fld[1]}",
f_dict[fw_fld[0]],
f"{utils.get_table_name_by_pattern(tcm_id_table_pattern,ch)}/{fw_fld.replace('.','/')}",
f_dict[[key for key in f_dict if key in fw_fld][0]],
idx=fl_idx,
)
if not isinstance(och, Array):
Expand Down

0 comments on commit 3399d20

Please sign in to comment.