Skip to content

Commit

Permalink
Rename register_* to add_feature_*, clean up APFitter var names
Browse files Browse the repository at this point in the history
Also worked on more consistent naming of variables in APFitter
implementation
  • Loading branch information
Dries Van De Putte committed May 17, 2024
1 parent 02cd40b commit 45d1497
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 49 deletions.
66 changes: 35 additions & 31 deletions pahfit/apfitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,20 +60,20 @@ class APFitter(Fitter):
def __init__(self):
"""Construct a new fitter.
After construction, use the register_() functions to start
setting up a model, then call finalize_model().
After construction, use the add_feature_() functions to start
setting up a model, then call finalize().
"""
self.additive_components = []
self.multiplicative_components = []
self.component_types = {}
self.feature_types = {}
self.model = None
self.message = None

def components(self):
"""Return list of component names.
Only works after finalize_model().
Only works after finalize().
"""
if hasattr(self.model, "submodel_names"):
Expand All @@ -82,10 +82,10 @@ def components(self):
# single-component edge case
return [self.model.name]

def finalize_model(self):
def finalize(self):
"""Sum the registered components into one CompoundModel.
To be called after a series of "register_()" calls, and before
To be called after a series of "add_feature_()" calls, and before
using the other functionality (fitting and evaluating).
"""
Expand All @@ -99,8 +99,10 @@ def finalize_model(self):
for c in self.multiplicative_components:
self.model *= c

def _register_component(self, astropy_model_class, multiplicative=False, **kwargs):
"""Register any component in a generic way.
def _add_component(self, astropy_model_class, multiplicative=False, **kwargs):
"""Generically add any feature as an astropy model component.
To be finalized with finalize()
Parameters
----------
Expand All @@ -115,7 +117,7 @@ def _register_component(self, astropy_model_class, multiplicative=False, **kwarg
kwargs : dict
Arguments for the astropy model constructor, including a
unique value for "name". Should be generated with
self._astropy_model_kwargs; the register_() functions show how
self._astropy_model_kwargs; the add_feature_() functions show how
to do this for each type of feature.
"""
Expand All @@ -124,7 +126,7 @@ def _register_component(self, astropy_model_class, multiplicative=False, **kwarg
else:
self.additive_components.append(astropy_model_class(**kwargs))

def register_starlight(self, name, temperature, tau):
def add_feature_starlight(self, name, temperature, tau):
"""Register a BlackBody1D.
Parameters
Expand All @@ -141,76 +143,76 @@ def register_starlight(self, name, temperature, tau):
tau : analogous. Used as amplitude.
"""
self.component_types[name] = "starlight"
self.feature_types[name] = "starlight"
kwargs = self._astropy_model_kwargs(
name, ["temperature", "amplitude"], [temperature, tau]
)
self._register_component(BlackBody1D, **kwargs)
self._add_component(BlackBody1D, **kwargs)

def register_dust_continuum(self, name, temperature, tau):
def add_feature_dust_continuum(self, name, temperature, tau):
"""Register a ModifiedBlackBody1D.
Analogous. Temperature and tau are used as temperature and
amplitude
"""
self.component_types[name] = "dust_continuum"
self.feature_types[name] = "dust_continuum"
kwargs = self._astropy_model_kwargs(
name, ["temperature", "amplitude"], [temperature, tau]
)
self._register_component(ModifiedBlackBody1D, **kwargs)
self._add_component(ModifiedBlackBody1D, **kwargs)

def register_line(self, name, power, wavelength, fwhm):
def add_feature_line(self, name, power, wavelength, fwhm):
"""Register a PowerGaussian1D
Analogous. Uses an implementation of the Gaussian profile, that
directly fits the power based on the internal PAHFIT units.
"""
self.component_types[name] = "line"
self.feature_types[name] = "line"

kwargs = self._astropy_model_kwargs(
name,
["amplitude", "mean", "stddev"],
[power, wavelength, fwhm / 2.355],
[power, wavelength, fwhm / 2.355],
)
self._register_component(Gaussian1D, **kwargs)
self._add_component(Gaussian1D, **kwargs)

def register_dust_feature(self, name, power, wavelength, fwhm):
def add_feature_dust_feature(self, name, power, wavelength, fwhm):
"""Register a PowerDrude1D.
Analogous. Uses an implementation of the Drude profile that
directly fits the power based on the internal PAHFIT units.
"""
self.component_types[name] = "dust_feature"
self.feature_types[name] = "dust_feature"
kwargs = self._astropy_model_kwargs(
name, ["amplitude", "x_0", "fwhm"], [power, wavelength, fwhm]
)
self._register_component(Drude1D, **kwargs)
self._add_component(Drude1D, **kwargs)

def register_attenuation(self, name, tau):
def add_feature_attenuation(self, name, tau):
"""Register the S07 attenuation component.
Analogous. Uses tau as tau_sil for S07_attenuation. Is
multiplicative.
"""
self.component_types[name] = "attenuation"
self.feature_types[name] = "attenuation"
kwargs = self._astropy_model_kwargs(name, ["tau_sil"], [tau])
self._register_component(S07_attenuation, multiplicative=True, **kwargs)
self._add_component(S07_attenuation, multiplicative=True, **kwargs)

def register_absorption(self, name, tau, wavelength, fwhm):
def add_feature_absorption(self, name, tau, wavelength, fwhm):
"""Register an absorbing Drude1D component.
Analogous. Is multiplicative.
"""
self.component_types[name] = "absorption"
self.feature_types[name] = "absorption"
kwargs = self._astropy_model_kwargs(
name, ["tau", "x_0", "fwhm"], [tau, wavelength, fwhm]
)
self._register_component(att_Drude1D, multiplicative=True, **kwargs)
self._add_component(att_Drude1D, multiplicative=True, **kwargs)

def evaluate_model(self, xz):
"""Evaluate internal astropy model with its current parameters.
Expand Down Expand Up @@ -298,7 +300,7 @@ def get_result(self, component_name):
Parameters
----------
component_name : str
One of the names provided to any of the register_*() calls
One of the names provided to any of the add_feature_*() calls
made during setup. See also Fitter.components().
Returns
Expand All @@ -319,7 +321,7 @@ def get_result(self, component_name):
# CompoundModel but normal single-component model.
component = self.model

c_type = self.component_types[component_name]
c_type = self.feature_types[component_name]
if c_type == "starlight" or c_type == "dust_continuum":
return {
"temperature": component.temperature.value,
Expand Down Expand Up @@ -407,6 +409,8 @@ def _astropy_model_kwargs(component_name, param_names, param_values):
# astropy modeling)
kwargs[param_name] = value
kwargs["fixed"][param_name] = is_fixed
kwargs["bounds"][param_name] = [None if np.isinf(x) else x for x in (lo_bound, up_bound)]
kwargs["bounds"][param_name] = [
None if np.isinf(x) else x for x in (lo_bound, up_bound)
]

return kwargs
23 changes: 12 additions & 11 deletions pahfit/fitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,27 +36,28 @@ class Fitter(ABC):
During the Fitter setup, the initial values, bounds, and "fixed"
flags are passed using one function call for each component, e.g.
register_line(). Once all components have been added, the
finalize_model() function should be called; some subclasses (e.g.
add_feature_line()). Once all components have been added, the
finalize() function should be called; some subclasses (e.g.
APFitter) need to consolidate the registered components to prepare
the model that they manage for fitting. After this, fit() can be
called to apply the model and the astropy fitter to the data. The
results will then be retrievable for one component at a time, by
passing the component name to get_result().
"""

@abstractmethod
def components(self):
"""Return list of features.
Only works after finalize_model(). Will return the names passed
Only works after finalize(). Will return the names passed
using the register functions.
"""
pass

@abstractmethod
def finalize_model(self):
def finalize(self):
"""Process the registered features and prepare for fitting.
The register functions below allow adding individual features.
Expand All @@ -68,7 +69,7 @@ def finalize_model(self):
pass

@abstractmethod
def register_starlight(self, name, temperature, tau):
def add_feature_starlight(self, name, temperature, tau):
"""Register a starlight feature.
The exact representation depends on the implementation, but the
Expand All @@ -92,12 +93,12 @@ def register_starlight(self, name, temperature, tau):
pass

@abstractmethod
def register_dust_continuum(self, name, temperature, tau):
def add_feature_dust_continuum(self, name, temperature, tau):
"""Register a dust continuum feature."""
pass

@abstractmethod
def register_line(self, name, power, wavelength, fwhm):
def add_feature_line(self, name, power, wavelength, fwhm):
"""Register an emission line feature.
Typically a Gaussian profile.
Expand All @@ -106,7 +107,7 @@ def register_line(self, name, power, wavelength, fwhm):
pass

@abstractmethod
def register_dust_feature(self, name, power, wavelength, fwhm):
def add_feature_dust_feature(self, name, power, wavelength, fwhm):
"""Register a dust feature.
Typically a Drude profile.
Expand All @@ -115,7 +116,7 @@ def register_dust_feature(self, name, power, wavelength, fwhm):
pass

@abstractmethod
def register_attenuation(self, name, tau):
def add_feature_attenuation(self, name, tau):
"""Register the S07 attenuation component.
Other types of attenuation might be possible in the future. Is
Expand All @@ -125,7 +126,7 @@ def register_attenuation(self, name, tau):
pass

@abstractmethod
def register_absorption(self, name, tau, wavelength, fwhm):
def add_feature_absorption(self, name, tau, wavelength, fwhm):
"""Register an absorption feature.
Typically a Drude profile. Is multiplicative.
Expand Down Expand Up @@ -182,7 +183,7 @@ def get_result(self, feature_name):
Parameters
----------
component_name : str
One of the names provided to any of the register_() calls
One of the names provided to any of the add_feature_() calls
made during setup. See also Fitter.components().
Returns
Expand Down
14 changes: 7 additions & 7 deletions pahfit/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -878,12 +878,12 @@ def cleaned(features_tuple3):
name = row["name"]

if kind == "starlight":
fitter.register_starlight(
fitter.add_feature_starlight(
name, cleaned(row["temperature"]), cleaned(row["tau"])
)

elif kind == "dust_continuum":
fitter.register_dust_continuum(
fitter.add_feature_dust_continuum(
name, cleaned(row["temperature"]), cleaned(row["tau"])
)

Expand Down Expand Up @@ -911,23 +911,23 @@ def cleaned(features_tuple3):
else:
fwhm = cleaned(row["fwhm"])

fitter.register_line(
fitter.add_feature_line(
name, cleaned(row["power"]), cleaned(row["wavelength"]), fwhm
)

elif kind == "dust_feature":
fitter.register_dust_feature(
fitter.add_feature_dust_feature(
name,
cleaned(row["power"]),
cleaned(row["wavelength"]),
cleaned(row["fwhm"]),
)

elif kind == "attenuation":
fitter.register_attenuation(name, cleaned(row["tau"]))
fitter.add_feature_attenuation(name, cleaned(row["tau"]))

elif kind == "absorption":
fitter.register_absorption(
fitter.add_feature_absorption(
name,
cleaned(row["tau"]),
cleaned(row["wavelength"]),
Expand All @@ -939,7 +939,7 @@ def cleaned(features_tuple3):
f"Model components of kind {kind} are not implemented!"
)

fitter.finalize_model()
fitter.finalize()
return fitter

@staticmethod
Expand Down

0 comments on commit 45d1497

Please sign in to comment.