Skip to content

Commit

Permalink
Merge pull request #553 from ggmarshall/main
Browse files Browse the repository at this point in the history
pargen updates for new pydataobj version
  • Loading branch information
gipert authored Feb 1, 2024
2 parents 79d47bd + 2dee12b commit a0a4400
Show file tree
Hide file tree
Showing 5 changed files with 38 additions and 21 deletions.
4 changes: 2 additions & 2 deletions src/pygama/math/histogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ def get_fwfm(fraction, hist, bins, var=None, mx=None, dmx=0, bl=0, dbl=0, method
# interpolate between the two bins that cross the [fraction] line
# works well for high stats
if bin_lo < 1 or bin_hi >= len(hist)-1:
print(f"get_fwhm: can't interpolate ({bin_lo}, {bin_hi})")
log.debug(f"get_fwhm: can't interpolate ({bin_lo}, {bin_hi})")
return 0, 0

val_f = bl + fraction*(mx-bl)
Expand Down Expand Up @@ -403,7 +403,7 @@ def get_fwfm(fraction, hist, bins, var=None, mx=None, dmx=0, bl=0, dbl=0, method
return x_hi - x_lo, np.sqrt(dxl2 + dxh2)

else:
print(f"get_fwhm: unrecognized method {method}")
log.debug(f"get_fwhm: unrecognized method {method}")
return 0, 0


Expand Down
5 changes: 3 additions & 2 deletions src/pygama/pargen/cuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import lgdo.lh5 as lh5
import numpy as np
import pandas as pd
from lgdo.types import Table
from scipy import stats

import pygama.math.histogram as pgh
Expand Down Expand Up @@ -51,7 +52,7 @@ def generate_cuts(
output_dict = {}
if isinstance(data, pd.DataFrame):
pass
elif isinstance(data, lh5.Table):
elif isinstance(data, Table):
data = {entry: data[entry].nda for entry in get_keys(data, parameters)}
data = pd.DataFrame.from_dict(data)
elif isinstance(data, dict):
Expand Down Expand Up @@ -204,7 +205,7 @@ def get_cut_indexes(
keys = cut_dict.keys()
if isinstance(all_data, pd.DataFrame):
pass
elif isinstance(all_data, lh5.Table):
elif isinstance(all_data, Table):
cut_keys = list(cut_dict)
cut_keys.append(energy_param)
all_data = {
Expand Down
12 changes: 9 additions & 3 deletions src/pygama/pargen/energy_optimisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@
from matplotlib.colors import LogNorm
from scipy.optimize import curve_fit, minimize
from scipy.stats import chisquare, norm
from sklearn.exceptions import ConvergenceWarning
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, ConstantKernel
from sklearn.utils._testing import ignore_warnings

import pygama.math.histogram as pgh
import pygama.math.peak_fitting as pgf
Expand Down Expand Up @@ -922,8 +924,9 @@ def event_selection(
if not isinstance(kev_widths, list):
kev_widths = [kev_widths]

sto = lh5.LH5Store()
df = lh5.load_dfs(raw_files, ["daqenergy", "timestamp"], lh5_path)
df = sto.read(lh5_path, raw_files, field_mask=["daqenergy", "timestamp"])[
0
].view_as("pd")

if pulser_mask is None:
pulser_props = cts.find_pulser_properties(df, energy="daqenergy")
Expand Down Expand Up @@ -1067,7 +1070,7 @@ def event_selection(
log.warning("Less than half number of specified events found")
elif len(peak_ids[final_mask]) < 0.1 * n_events:
log.error("Less than 10% number of specified events found")
out_events = np.unique(np.array(out_events).flatten())
out_events = np.unique(np.concatenate(out_events))
sort_index = np.argsort(np.concatenate(final_events))
idx_list = get_wf_indexes(sort_index, [len(mask) for mask in final_events])
return out_events, idx_list
Expand Down Expand Up @@ -1381,6 +1384,7 @@ def get_first_point(self):
self.optimal_ei = None
return self.optimal_x, self.optimal_ei

@ignore_warnings(category=ConvergenceWarning)
def iterate_values(self):
nan_idxs = np.isnan(self.y_init)
self.gauss_pr.fit(self.x_init[~nan_idxs], np.array(self.y_init)[~nan_idxs])
Expand Down Expand Up @@ -1451,6 +1455,7 @@ def get_best_vals(self):
out_dict[name][parameter] = value_str
return out_dict

@ignore_warnings(category=ConvergenceWarning)
def plot(self, init_samples=None):
nan_idxs = np.isnan(self.y_init)
fail_idxs = np.isnan(self.yerr_init)
Expand Down Expand Up @@ -1557,6 +1562,7 @@ def plot(self, init_samples=None):
plt.close()
return fig

@ignore_warnings(category=ConvergenceWarning)
def plot_acq(self, init_samples=None):
nan_idxs = np.isnan(self.y_init)
self.gauss_pr.fit(self.x_init[~nan_idxs], np.array(self.y_init)[~nan_idxs])
Expand Down
10 changes: 6 additions & 4 deletions src/pygama/pargen/extract_tau.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import pygama.pargen.energy_optimisation as om

log = logging.getLogger(__name__)
sto = lh5.LH5Store()


def load_data(
Expand All @@ -36,8 +37,9 @@ def load_data(
threshold: int = 5000,
wf_field: str = "waveform",
) -> lgdo.Table:
sto = lh5.LH5Store()
df = lh5.load_dfs(raw_file, ["daqenergy", "timestamp"], lh5_path)
df = sto.read(lh5_path, raw_file, field_mask=["daqenergy", "timestamp"])[0].view_as(
"pd"
)

if pulser_mask is None:
pulser_props = cts.find_pulser_properties(df, energy="daqenergy")
Expand Down Expand Up @@ -142,8 +144,8 @@ def get_decay_constant(
)
axins.axvline(high_bin, color="red")
axins.set_xlim(bins[in_min], bins[in_max])
labels = ax.get_xticklabels()
ax.set_xticklabels(labels=labels, rotation=45)
ax.set_xticks(ax.get_xticks())
ax.set_xticklabels(labels=ax.get_xticklabels(), rotation=45)
out_plot_dict["slope"] = fig
if display > 1:
plt.show()
Expand Down
28 changes: 18 additions & 10 deletions src/pygama/pargen/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from lgdo import Table, lh5

log = logging.getLogger(__name__)
sto = lh5.LH5Store()


def return_nans(input):
Expand Down Expand Up @@ -50,8 +51,6 @@ def load_data(
Loads in the A/E parameters needed and applies calibration constants to energy
"""

sto = lh5.LH5Store()

out_df = pd.DataFrame(columns=params)

if isinstance(files, dict):
Expand Down Expand Up @@ -142,14 +141,23 @@ def get_tcm_pulser_ids(tcm_file, channel, multiplicity_threshold):
mask = np.append(mask, file_mask)
ids = np.where(mask)[0]
else:
data = lh5.load_dfs(tcm_file, ["array_id", "array_idx"], "hardware_tcm_1")
cum_length = lh5.load_nda(tcm_file, ["cumulative_length"], "hardware_tcm_1")[
"cumulative_length"
]
cum_length = np.append(np.array([0]), cum_length)
n_channels = np.diff(cum_length)
evt_numbers = np.repeat(np.arange(0, len(cum_length) - 1), np.diff(cum_length))
evt_mult = np.repeat(np.diff(cum_length), np.diff(cum_length))
data = pd.DataFrame(
{
"array_id": sto.read("hardware_tcm_1/array_id", tcm_file)[0].view_as(
"np"
),
"array_idx": sto.read("hardware_tcm_1/array_idx", tcm_file)[0].view_as(
"np"
),
}
)
cumulength = sto.read("hardware_tcm_1/cumulative_length", tcm_file)[0].view_as(
"np"
)
cumulength = np.append(np.array([0]), cumulength)
n_channels = np.diff(cumulength)
evt_numbers = np.repeat(np.arange(0, len(cumulength) - 1), np.diff(cumulength))
evt_mult = np.repeat(np.diff(cumulength), np.diff(cumulength))
data["evt_number"] = evt_numbers
data["evt_mult"] = evt_mult
high_mult_events = np.where(n_channels > multiplicity_threshold)[0]
Expand Down

0 comments on commit a0a4400

Please sign in to comment.