-
Notifications
You must be signed in to change notification settings - Fork 12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Updated transforms base #456
Updated transforms base #456
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks awesome!
Two things:
- Could you write a more explicit PR description that shows the API change?
- I suggest to add a slightly more elaborate test, see below.
Thanks!
|
||
assert np.allclose( | ||
inverse[0]["param_array_1"], params[0]["param_array_1"] | ||
), f"{transform} forward, inverse failed." |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome! Can we have some more elaborate test that checks whether it plays well with .get_parameters()
?
E.g.
comp = jx.Compartment()
branch = jx.Branch(comp, nseg=2)
cell = jx.Cell(branch, parents=[-1, 0, 0])
cell.branch("all").make_trainable("radius")
cell.branch(2).make_trainable("radius")
cell.branch(1).make_trainable("length")
cell.branch(0).make_trainable("radius")
cell.branch("all").comp("all").make_trainable("axial_resistivity")
cell.make_trainable("capacitance")
transform = ...
params = cell.get_parameters()
transformed_params = transform.forward(params)
recovered_params = transform.inverse(transformed_params)
# ...test that they are still the same, or that transformed_params has no NaN.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added it.
I do scale everything by 1/1000 (as axial resistivity is 5000 for which most transforms are not numerically invertible)
Alright, I did add an API description with an example. There might be a few other considerations:
|
I updated the tutorial notebook with the new parameter transform function. I think (at a later point) it would be nice to have a longer section (or separate tutorial) to describe how to use them! Note that the |
* Updated transforms base * More transforms tests and doc * refactored * add tests for plain Arrays not PyTrees (datasets) * Updating some docs * update tutorial 7 notebook with new transforms --------- Co-authored-by: Matthijs <[email protected]>
Transform API Changes
This pull request introduces modifications to the
ParamsTransform
API. Now, transforms must be explicitly specified for each parameter, ensuring more precise control over their application. Additionally, the structure of the transforms (such as PyTrees or dictionaries) must match the structure of the parameters themselves. For further details, please refer to issue #343.Why These Changes?
The primary goal of these changes is to give and enforce manual control of reparameterizations to the User. Previously this was happening automatically, but suboptimal.
Example Usage
Below is an example demonstrating how to apply the updated API with custom transforms to various cell parameters in a
jaxley
model:Key Benefits