diff --git a/src/pygama/pargen/utils.py b/src/pygama/pargen/utils.py index 91002a007..ec0cab06f 100644 --- a/src/pygama/pargen/utils.py +++ b/src/pygama/pargen/utils.py @@ -9,7 +9,6 @@ from lgdo import lh5 log = logging.getLogger(__name__) -sto = lh5.LH5Store() def convert_to_minuit(pars, func): @@ -35,101 +34,116 @@ def return_nans(input): return m.values, m.errors, np.full((len(m.values), len(m.values)), np.nan) -def get_params(file_params, param_list): - out_params = [] - if isinstance(file_params, dict): - possible_keys = file_params.keys() - elif isinstance(file_params, list): - possible_keys = file_params - for param in param_list: - for key in possible_keys: - if key in param: - out_params.append(key) - return np.unique(out_params).tolist() - - def load_data( - files: list, + files: str | list | dict, lh5_path: str, cal_dict: dict, - params: list, + params: set, cal_energy_param: str = "cuspEmax_ctc_cal", threshold=None, return_selection_mask=False, -) -> tuple(np.array, np.array, np.array, np.array): +) -> pd.DataFrame | tuple(pd.DataFrame, np.array): """ - Loads in the A/E parameters needed and applies calibration constants to energy + Loads parameters from data files. Applies calibration to cal_energy_param + and uses this to apply a lower energy threshold. + + files + file or list of files or dict pointing from timestamps to lists of files + lh5_path + path to table in files + cal_dict + dictionary with operations used to apply calibration constants + params + list of parameters to load from file + cal_energy_param + name of uncalibrated energy parameter + threshold + lower energy threshold for events to load + return_selection_map + if True, return selection mask for threshold along with data """ + params = set(params) if isinstance(files, str): files = [files] if isinstance(files, dict): - keys = lh5.ls( - files[list(files)[0]][0], - lh5_path if lh5_path[-1] == "/" else lh5_path + "/", - ) - keys = [key.split("/")[-1] for key in keys] - if list(files)[0] in cal_dict: - params = get_params(keys + list(cal_dict[list(files)[0]].keys()), params) - else: - params = get_params(keys + list(cal_dict.keys()), params) - + # Go through each tstamp and recursively load_data on file lists df = [] - all_files = [] - masks = np.array([], dtype=bool) + masks = [] for tstamp, tfiles in files.items(): - table = sto.read(lh5_path, tfiles)[0] - - file_df = pd.DataFrame(columns=params) - if tstamp in cal_dict: - cal_dict_ts = cal_dict[tstamp] + file_df = load_data( + tfiles, + lh5_path, + cal_dict.get(tstamp, cal_dict), + params, + cal_energy_param, + threshold, + return_selection_mask, + ) + + if return_selection_mask: + file_df[0]["run_timestamp"] = np.full( + len(file_df[0]), tstamp, dtype=object + ) + df.append(file_df[0]) + masks.append(file_df[1]) else: - cal_dict_ts = cal_dict + file_df["run_timestamp"] = np.full(len(file_df), tstamp, dtype=object) + df.append(file_df) - for outname, info in cal_dict_ts.items(): - outcol = table.eval(info["expression"], info.get("parameters", None)) - table.add_column(outname, outcol) - - for param in params: - file_df[param] = table[param] + df = pd.concat(df) + if return_selection_mask: + masks = np.concatenate(masks) - file_df["run_timestamp"] = np.full(len(file_df), tstamp, dtype=object) + elif isinstance(files, list): + # Get set of available fields between input table and cal_dict + file_keys = lh5.ls( + files[0], lh5_path if lh5_path[-1] == "/" else lh5_path + "/" + ) + file_keys = {key.split("/")[-1] for key in file_keys} - if threshold is not None: - mask = file_df[cal_energy_param] > threshold - file_df.drop(np.where(~mask)[0], inplace=True) - else: - mask = np.ones(len(file_df), dtype=bool) - masks = np.append(masks, mask) - df.append(file_df) - all_files += tfiles + # Get set of keys in calibration expressions that show up in file + cal_keys = { + name + for info in cal_dict.values() + for name in compile(info["expression"], "0vbb is real!", "eval").co_names + } & file_keys - params.append("run_timestamp") - df = pd.concat(df) + # Get set of fields to read from files + fields = cal_keys | (file_keys & params) - elif isinstance(files, list): - keys = lh5.ls(files[0], lh5_path if lh5_path[-1] == "/" else lh5_path + "/") - keys = [key.split("/")[-1] for key in keys] - params = get_params(keys + list(cal_dict.keys()), params) - - table = sto.read(lh5_path, files)[0] - df = pd.DataFrame(columns=params) - for outname, info in cal_dict.items(): - outcol = table.eval(info["expression"], info.get("parameters", None)) - table.add_column(outname, outcol) - for param in params: - df[param] = table[param] + lh5_it = lh5.iterator.LH5Iterator( + files, lh5_path, field_mask=fields, buffer_len=100000 + ) + df_fields = params & (fields | set(cal_dict)) + if df_fields != params: + log.debug( + f"load_data(): params not found in data files or cal_dict: {params-df_fields}" + ) + df = pd.DataFrame(columns=list(df_fields)) + + for table, entry, n_rows in lh5_it: + # Evaluate all provided expressions and add to table + for outname, info in cal_dict.items(): + table[outname] = table.eval( + info["expression"], info.get("parameters", None) + ) + + # Copy params in table into dataframe + for par in df: + # First set of entries: allocate enough memory for all entries + if entry == 0: + df[par] = np.resize(table[par], len(lh5_it)) + else: + df.loc[entry : entry + n_rows - 1, par] = table[par][:n_rows] + + # Evaluate threshold mask and drop events below threshold if threshold is not None: masks = df[cal_energy_param] > threshold df.drop(np.where(~masks)[0], inplace=True) else: masks = np.ones(len(df), dtype=bool) - all_files = files - - for col in list(df.keys()): - if col not in params: - df.drop(col, inplace=True, axis=1) log.debug("data loaded") if return_selection_mask: diff --git a/src/pygama/utils.py b/src/pygama/utils.py index 888ca396c..b6b73d26c 100644 --- a/src/pygama/utils.py +++ b/src/pygama/utils.py @@ -51,7 +51,7 @@ class NumbaPygamaDefaults(MutableMapping): """ def __init__(self) -> None: - self.parallel: bool = getenv_bool("PYGAMA_PARALLEL", default=True) + self.parallel: bool = getenv_bool("PYGAMA_PARALLEL", default=False) self.fastmath: bool = getenv_bool("PYGAMA_FASTMATH", default=True) def __getitem__(self, item: str) -> Any: