diff --git a/pahfit/features/features.py b/pahfit/features/features.py index db7d300..4bb56f7 100644 --- a/pahfit/features/features.py +++ b/pahfit/features/features.py @@ -23,7 +23,7 @@ from astropy.io.misc.yaml import yaml from importlib import resources from pahfit.errors import PAHFITFeatureError -from pahfit.features.features_format import BoundedMaskedColumn, BoundedParTableFormatter +from pahfit.features.features_format import BoundedParTableFormatter import pahfit.units # Feature kinds and associated parameters @@ -86,8 +86,8 @@ def value_bounds(val, bounds): Returns: ------- - The value, if unbounded, or a 3 element tuple (value, min, max). - Any missing bound is replaced with the numpy `masked' value. + A 3 element tuple (value, min, max). + Any missing bound is replaced with the numpy.nan value. Raises: ------- @@ -99,7 +99,7 @@ def value_bounds(val, bounds): if val is None: val = np.ma.masked if not bounds: - return (val,) + 2 * (np.ma.masked,) # Fixed + return (val,) + 2 * (np.nan,) # (val,nan,nan) indicates fixed ret = [val] for i, b in enumerate(bounds): if isinstance(b, str): @@ -132,12 +132,12 @@ class Features(Table): """ TableFormatter = BoundedParTableFormatter - MaskedColumn = BoundedMaskedColumn param_covar = TableAttribute(default=[]) - _param_attrs = set(('value', 'bounds')) # params can have these attributes _group_attrs = set(('bounds', 'features', 'kind')) # group-level attributes - _no_bounds = set(('name', 'group', 'geometry', 'model')) # String attributes (no bounds) + _param_attrs = set(('value', 'bounds')) # Each parameter can have these attributes + _no_bounds = set(('name', 'group', 'kind', 'geometry', 'model')) # str attributes (no bounds) + _bounds_dtype = np.dtype([("val", "f4"), ("min", "f4"), ("max", "f4")]) # bounded param type @classmethod def read(cls, file, *args, **kwargs): @@ -332,11 +332,9 @@ def _construct_table(cls, inp: dict): else: params[missing] = value_bounds(0.0, bounds=(0.0, None)) rows.append(dict(name=name, **params)) - table_columns = rows[0].keys() - t = cls(rows, names=table_columns) - for p in KIND_PARAMS[kind]: - if p not in cls._no_bounds: - t[p].info.format = "0.4g" # Nice format (customized by Formatter) + param_names = rows[0].keys() + dtypes = [str if x in cls._no_bounds else cls._bounds_dtype for x in param_names] + t = cls(rows, names=param_names, dtype=dtypes) tables.append(t) tables = vstack(tables) for cn, col in tables.columns.items(): @@ -376,8 +374,12 @@ def mask_feature(self, name, mask_value=True): pass else: # mask only the value, not the bounds - row[col_name].mask[0] = mask_value + row[col_name].mask['val'] = mask_value def unmask_feature(self, name): """Remove the mask for all parameters of a feature.""" self.mask_feature(name, mask_value=False) + + def _base_repr_(self, *args, **kwargs): + """Omit dtype on self-print.""" + return super()._base_repr_(*args, ** kwargs | dict(show_dtype=False)) diff --git a/pahfit/features/features_format.py b/pahfit/features/features_format.py index 7836972..2871e93 100644 --- a/pahfit/features/features_format.py +++ b/pahfit/features/features_format.py @@ -1,53 +1,45 @@ -import numpy.ma as ma -from astropy.table import MaskedColumn +import numpy as np from astropy.table.pprint import TableFormatter # * Special table formatting for bounded (val, min, max) values -def fmt_func(fmt): - def _fmt(v): - if ma.is_masked(v[0]): - return " " - if ma.is_masked(v[1]): - return f"{v[0]:{fmt}} (Fixed)" - return f"{v[0]:{fmt}} ({v[1]:{fmt}}, {v[2]:{fmt}})" - +def fmt_func(fmt: str): + """Format bounded variables specially.""" + if fmt.startswith('%'): + fmt = fmt[1:] + + def _fmt(x): + ret = f"{x['val']:{fmt}}" + if np.isnan(x['min']) and np.isnan(x['max']): + return ret + " (fixed)" + else: + mn = ("-∞" if np.isnan(x['min']) or x['min'] == -np.inf + else f"{x['min']:{fmt}}") + mx = ("∞" if np.isnan(x['max']) or x['max'] == np.inf + else f"{x['max']:{fmt}}") + return f"{ret} ({mn}, {mx})" return _fmt -class BoundedMaskedColumn(MaskedColumn): - """Masked column which can be toggled to group rows into one item - for formatting. To be set as Table's `MaskedColumn'. - """ - - _omit_shape = False - - @property - def shape(self): - sh = super().shape - return sh[0:-1] if self._omit_shape and len(sh) > 1 else sh - - def is_fixed(self): - return ma.getmask(self)[:, 1:].all(1) - - class BoundedParTableFormatter(TableFormatter): """Format bounded parameters. Bounded parameters are 3-field structured arrays, with fields - 'var', 'min', and 'max'. To be set as Table's `TableFormatter'. + 'val', 'min', and 'max'. To be set as Table's `TableFormatter'. """ - def _pformat_table(self, table, *args, **kwargs): bpcols = [] + tlfmt = table.meta.get('pahfit_format') try: - colsh = [(col, col.shape) for col in table.columns.values()] - BoundedMaskedColumn._omit_shape = True - for col, sh in colsh: - if len(sh) == 2 and sh[1] == 3: + for col in table.columns.values(): + if len(col.dtype) == 3: # bounded! bpcols.append((col, col.info.format)) - col.info.format = fmt_func(col.info.format or "g") + fmt = col.meta.get('pahfit_format') or tlfmt or "g" + col.info.format = fmt_func(fmt) return super()._pformat_table(table, *args, **kwargs) finally: - BoundedMaskedColumn._omit_shape = False for col, fmt in bpcols: col.info.format = fmt + + def _name_and_structure(self, name, *args): + "Simplified column name: no val, min, max needed." + return name diff --git a/pahfit/features/util.py b/pahfit/features/util.py index 68513aa..d9493f9 100644 --- a/pahfit/features/util.py +++ b/pahfit/features/util.py @@ -1,29 +1,28 @@ """pahfit.util General pahfit.features utility functions.""" import numpy as np -import numpy.ma as ma def bounded_is_missing(val): """Return a mask array indicating which of the bounded values are missing. A missing bounded value has a masked value.""" - return ma.getmask(val)[..., 0] + return getattr(val['val'], 'mask', None) or np.zeros_like(val['val'], dtype=bool) def bounded_is_fixed(val): """Return a mask array indicating which of the bounded values are fixed. A fixed bounded value has masked bounds.""" - return ma.getmask(val)[..., -2:].all(-1) + return np.isnan(val['min']) & np.isnan(val['max']) def bounded_min(val): """Return the minimum of each bounded value passed. Either the lower bound, or, if no such bound is set, the value itself.""" - lower = val[..., 1] - return np.where(lower, lower, val[..., 0]) + lower = val['min'] + return np.where(lower, lower, val['val']) def bounded_max(val): """Return the maximum of each bounded value passed. Either the upper bound, or, if no such bound is set, the value itself.""" - upper = val[..., 2] - return np.where(upper, upper, val[..., 0]) + upper = val['max'] + return np.where(upper, upper, val['val'])