
Updated transforms base #456

Conversation


@manuelgloeckler commented Oct 23, 2024

Transform API Changes

This pull request modifies the ParamTransform API. Transforms now have to be specified explicitly for each parameter, which gives more precise control over where and how they are applied. Additionally, the structure of the transforms (such as PyTrees or dictionaries) must match the structure of the parameters themselves. For further details, please refer to issue #343.

Why These Changes?

The primary goal of these changes is to give the user explicit, manual control over reparameterizations. Previously, this happened automatically, but the automatic choices were suboptimal.

Example Usage

Below is an example demonstrating how to apply the updated API with custom transforms to various cell parameters in a jaxley model:

import jaxley as jx
import jaxley.optimize.transforms as jt

# Define a model with compartments and branches
comp = jx.Compartment()
branch = jx.Branch(comp, nseg=2)
cell = jx.Cell(branch, parents=[-1, 0, 0])

# Specify which parameters to make trainable
cell.branch("all").make_trainable("radius")
cell.branch(2).make_trainable("radius")
cell.branch(1).make_trainable("length")
cell.branch(0).make_trainable("radius")
cell.branch("all").comp("all").make_trainable("axial_resistivity")
cell.make_trainable("capacitance")

# Retrieve the parameters
params = cell.get_parameters()

# Define a function to create appropriate transforms for each parameter
def create_transform(name):
    if name == "axial_resistivity":
        # Must be positive; apply Softplus and scale to match initialization
        return jt.ChainTransform([jt.SoftplusTransform(0), jt.AffineTransform(5000, 0)])
    elif name == "length":
        # Apply Softplus and affine transform for the 'length' parameter
        return jt.ChainTransform([jt.SoftplusTransform(0), jt.AffineTransform(10, 0)])
    else:
        # Default to a Softplus transform for other parameters
        return jt.SoftplusTransform(0)

# Apply the transforms to the parameters
transforms = [{k: create_transform(k) for k in param} for param in params]
tf = jt.ParamTransform(transforms)

# Obtain the unconstrained parameters
params_unconstrained = tf.inverse(params)
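As a quick usage note (not part of the PR description; a sketch that only assumes tf.forward and tf.inverse are inverses of each other, which the test further below checks), the unconstrained parameters can be mapped back to the constrained space with tf.forward:

# Sketch: map the unconstrained parameters back to the constrained space.
# params_recovered should match params up to numerical precision.
params_recovered = tf.forward(params_unconstrained)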

Key Benefits

  • Explicit Transformations: This change makes it clear which transformation is applied to each parameter, increasing transparency and control.

@manuelgloeckler linked an issue on Oct 23, 2024 that may be closed by this pull request

@michaeldeistler left a comment


Looks awesome!

Two things:

  1. Could you write a more explicit PR description that shows the API change?
  2. I suggest adding a slightly more elaborate test; see below.

Thanks!

jaxley/optimize/transforms.py (review thread resolved)

assert np.allclose(
    inverse[0]["param_array_1"], params[0]["param_array_1"]
), f"{transform} forward, inverse failed."

Awesome! Can we have a somewhat more elaborate test that checks whether it plays well with .get_parameters()?

E.g.

comp = jx.Compartment()
branch = jx.Branch(comp, nseg=2)
cell = jx.Cell(branch, parents=[-1, 0, 0])

cell.branch("all").make_trainable("radius")
cell.branch(2).make_trainable("radius")
cell.branch(1).make_trainable("length")
cell.branch(0).make_trainable("radius")
cell.branch("all").comp("all").make_trainable("axial_resistivity")
cell.make_trainable("capacitance")

transform = ...

params = cell.get_parameters()
transformed_params = transform.forward(params)
recovered_params = transform.inverse(transformed_params)
# ...test that they are still the same, or that transformed_params has no NaN.
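A sketch of how that final check could look (hypothetical; it assumes params, transformed_params, and recovered_params are lists of {name: array} dicts, as returned by .get_parameters()):

import numpy as np

# Check the roundtrip and that the transformed values contain no NaNs.
for p, r, t in zip(params, recovered_params, transformed_params):
    for key in p:
        assert np.allclose(p[key], r[key]), f"roundtrip failed for {key}"
        assert not np.any(np.isnan(t[key])), f"NaN in transformed {key}"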

@manuelgloeckler (PR author) replied:

Added it.

In the test I scale everything by 1/1000 (since axial resistivity is 5000, for which most transforms are not numerically invertible).
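A hypothetical sketch of that rescaling (the actual test may implement it differently):

# Shrink the parameter values by 1/1000 so that e.g. axial_resistivity = 5000
# becomes 5, keeping the forward/inverse roundtrip numerically stable.
scaled_params = [{k: v / 1000.0 for k, v in p.items()} for p in params]
transformed_params = transform.forward(scaled_params)
recovered_params = transform.inverse(transformed_params)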

@manuelgloeckler commented:

Alright, I added an API description with an example.

There might be a few other considerations:

  • We could name the transforms after what they achieve rather than how they work (i.e. Sigmoid -> Bound, Softplus -> BoundLower). That would make them easier to use, but also make it less visible what is actually happening.
  • We could provide a "default reparameterization", i.e. a dict mapping all parameter names that can come up in jaxley to a sensible reparameterization (essentially what create_transform does in the example above); a rough sketch of such a mapping is shown below.
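A hypothetical sketch of such a default mapping (the parameter names and transform choices are purely illustrative, taken from the example above, and not part of this PR):

import jaxley.optimize.transforms as jt

# Illustrative default transform per parameter name (not part of the PR).
default_transforms = {
    "radius": jt.SoftplusTransform(0),
    "length": jt.SoftplusTransform(0),
    "axial_resistivity": jt.ChainTransform(
        [jt.SoftplusTransform(0), jt.AffineTransform(5000, 0)]
    ),
    "capacitance": jt.SoftplusTransform(0),
}

# It could then be combined with get_parameters() just like create_transform:
transforms = [{k: default_transforms[k] for k in param} for param in params]
tf = jt.ParamTransform(transforms)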

@Matthijspals commented Oct 24, 2024

I updated the tutorial notebook with the new parameter transform function. I think it would be nice (at a later point) to have a longer section (or a separate tutorial) describing how to use them!

Note that the jit function in the notebook is commented out at the moment, see Issue #467 (so it probably makes sense to hold off on merging until that issue is fixed and the notebook is updated again).

@michaeldeistler merged commit d247ad9 into main on Oct 24, 2024
1 check passed
@michaeldeistler deleted the 343-softplus-transform-does-not-properly-rescale-parameters branch on October 24, 2024 at 14:54
michaeldeistler pushed a commit that referenced this pull request Nov 13, 2024
* Updated transforms base

* More transforms tests and doc

* refactored

* add tests for plain Arrays not PyTrees (datasets)

* Updating some docs

* update tutorial 7 notebook with new transforms

---------

Co-authored-by: Matthijs <[email protected]>

Successfully merging this pull request may close these issues.

softplus transform does not properly rescale parameters