Skip to content

Commit

Permalink
style: pre-commit fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
pre-commit-ci[bot] committed Feb 2, 2024
1 parent e6129af commit 96a07cc
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 15 deletions.
3 changes: 2 additions & 1 deletion src/pygama/evt/aggregators.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,5 +684,6 @@ def evaluate_to_vector(
)

return VectorOfVectors(
ak.values_astype(ak.drop_none(ak.nan_to_none(ak.Array(out))), type(defv)),dtype=type(defv)
ak.values_astype(ak.drop_none(ak.nan_to_none(ak.Array(out))), type(defv)),
dtype=type(defv),
)
10 changes: 3 additions & 7 deletions src/pygama/pargen/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,17 +69,15 @@ def load_data(
masks = np.array([], dtype=bool)
for tstamp, tfiles in files.items():
table = sto.read(lh5_path, tfiles)[0]

file_df = pd.DataFrame(columns=params)
if tstamp in cal_dict:
cal_dict_ts = cal_dict[tstamp]
else:
cal_dict_ts = cal_dict

for outname, info in cal_dict_ts.items():
outcol = table.eval(
info["expression"], info.get("parameters", None)
)
outcol = table.eval(info["expression"], info.get("parameters", None))
table.add_column(outname, outcol)

for param in params:
Expand All @@ -105,9 +103,7 @@ def load_data(
table = sto.read(lh5_path, files)[0]
df = pd.DataFrame(columns=params)
for outname, info in cal_dict.items():
outcol = table.eval(
info["expression"], info.get("parameters", None)
)
outcol = table.eval(info["expression"], info.get("parameters", None))
table.add_column(outname, outcol)
for param in params:
df[param] = table[param]
Expand Down
18 changes: 11 additions & 7 deletions src/pygama/skm/build_skm.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,10 @@ def build_skm(
fw_fld = tbl_cfg["operations"][op]["forward_field"]

# load object if from evt tier
if evt_group in fw_fld.replace('.','/'):
obj = store.read(f"/{fw_fld.replace('.','/')}", f_dict[fw_fld.split(".",1)[0]])[
0
].view_as("ak")
if evt_group in fw_fld.replace(".", "/"):
obj = store.read(
f"/{fw_fld.replace('.','/')}", f_dict[fw_fld.split(".", 1)[0]]
)[0].view_as("ak")

# else collect data from lower tier via tcm_idx
else:
Expand All @@ -155,7 +155,8 @@ def build_skm(
)
tcm_idx_fld = tbl_cfg["operations"][op]["tcm_idx"]
tcm_idx = store.read(
f"/{tcm_idx_fld.replace('.','/')}", f_dict[tcm_idx_fld.split(".")[0]]
f"/{tcm_idx_fld.replace('.','/')}",
f_dict[tcm_idx_fld.split(".")[0]],
)[0].view_as("ak")[:, :multi]

obj = ak.Array([[] for x in range(len(tcm_idx))])
Expand All @@ -182,10 +183,13 @@ def build_skm(
ch_idx = idx[ids == ch]
ct_idx = ak.count(ch_idx, axis=-1)
fl_idx = ak.to_numpy(ak.flatten(ch_idx), allow_missing=False)

if (
f"{utils.get_table_name_by_pattern(tcm_id_table_pattern,ch)}/{fw_fld.replace('.','/')}"
not in lh5.ls(f_dict[[key for key in f_dict if key in fw_fld][0]], f"ch{ch}/{fw_fld.rsplit('.',1)[0]}/")
not in lh5.ls(
f_dict[[key for key in f_dict if key in fw_fld][0]],
f"ch{ch}/{fw_fld.rsplit('.',1)[0]}/",
)
):
och = Array(nda=np.full(len(fl_idx), miss_val))
else:
Expand Down

0 comments on commit 96a07cc

Please sign in to comment.