
Move more things to pytrees, at least internally #557

Open
jnsbck opened this issue Dec 17, 2024 · 0 comments


@jnsbck
Contributor

jnsbck commented Dec 17, 2024

I think we should consider moving more things to pytrees, e.g. recordings and externals. This would let us use `tree_map` and friends to operate on them more efficiently. Related to #555.
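A minimal sketch of what this would buy us, assuming externals and recordings are stored as plain (possibly nested) dicts of arrays. The keys, nesting, and shapes below are hypothetical and not the current internal layout:

```python
import jax
import jax.numpy as jnp

# Hypothetical externals: one trace of shape (num_timesteps,) per stimulated location.
externals = {
    "i": [jnp.ones(100), 2.0 * jnp.ones(100)],
    "v": [jnp.zeros(100)],
}

# A single call reaches every leaf, however deeply the structure is nested.
scaled = jax.tree_util.tree_map(lambda x: 0.5 * x, externals)

# Slice all traces at once, e.g. to simulate in chunks.
first_chunk = jax.tree_util.tree_map(lambda x: x[:10], externals)
num_leaves = len(jax.tree_util.tree_leaves(externals))

# Hypothetical recordings from two runs with identical structure:
# tree_map over several trees stacks them leaf by leaf into a batch.
recs_run1 = {"v": jnp.zeros(100), "ca": jnp.ones(100)}
recs_run2 = {"v": jnp.ones(100), "ca": jnp.zeros(100)}
batched = jax.tree_util.tree_map(lambda a, b: jnp.stack([a, b]), recs_run1, recs_run2)
```

The same calls keep working unchanged if more keys or deeper nesting are added later, which is the main appeal over bookkeeping each array by hand.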

For example, we could make channels and synapses pytrees. Then we would not need to add their params to `jaxnodes` / `jaxedges` separately. It would also allow params / states to be nested, which would be useful for NODEs as well.

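```python
from jax import tree_util


class Mechanism:
    def __init__(self):
        # Leaves: numeric params and states that JAX traces / differentiates.
        self.params = {"p1": 1, "p2": 2}
        self.states = {"s1": 1, "s2": 2}
        # Static metadata, kept out of the leaves.
        self.name = "mech_name"

    def update_states(self, *args, **kwargs):
        # Placeholder for the mechanism's state update.
        return

    def tree_flatten(self):
        # Returns (children, aux_data): children are pytree leaves, aux_data is static.
        return (self.params, self.states), self.name

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild without calling __init__, so defaults are not re-applied.
        obj = cls.__new__(cls)
        obj.params, obj.states = children
        obj.name = aux_data
        return obj


tree_util.register_pytree_node(
    Mechanism,
    Mechanism.tree_flatten,
    Mechanism.tree_unflatten,
)
```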
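A rough sketch of how such a pytree mechanism could then be passed straight into `jax.grad` and `jax.jit`, including nested params. To run on its own it re-declares the class with the `tree_util.register_pytree_node_class` decorator (equivalent to the `register_pytree_node` call above); the constructor arguments, the nested `"mlp"` params, the toy loss, and `scale_params` are all hypothetical:

```python
import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class Mechanism:
    def __init__(self, params, states, name):
        self.params = params
        self.states = states
        self.name = name

    def tree_flatten(self):
        # params and states are leaves; name is static aux data.
        return (self.params, self.states), self.name

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        obj = cls.__new__(cls)
        obj.params, obj.states = children
        obj.name = aux_data
        return obj


# Hypothetical channel whose params are nested, e.g. a small MLP for a NODE-style term.
channel = Mechanism(
    params={"gbar": jnp.array(0.1), "mlp": {"w": jnp.ones((2, 2)), "b": jnp.zeros(2)}},
    states={"m": jnp.array(0.5)},
    name="hypothetical_channel",
)


def loss(mech):
    # Toy objective that touches every parameter leaf, however nested.
    return sum(jnp.sum(x**2) for x in tree_util.tree_leaves(mech.params))


# The gradient comes back as a Mechanism with the same structure.
grads = jax.grad(loss)(channel)


@jax.jit
def scale_params(mech, factor):
    # Returns a new Mechanism; `name` passes through untouched as static aux data.
    new_params = tree_util.tree_map(lambda p: factor * p, mech.params)
    return Mechanism(new_params, mech.states, mech.name)


doubled = scale_params(channel, 2.0)
```

Because `name` lives in the aux data, it is static: it never receives a gradient, and changing it triggers a retrace under `jit`. Note that `jax.grad` also returns (zero) gradients for `states` here; in practice one would probably differentiate only with respect to `params`.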

The Modules themselves could potentially be converted to pytrees as well.
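If Modules were pytrees too, any mechanisms they hold would be flattened recursively, so one `tree_map` could reach every parameter in a cell. A minimal illustration with hypothetical `Channel` and `Module` classes that hold only arrays; a real Module carries much more non-array state, which would have to be split between leaves and static `aux_data`:

```python
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
class Channel:
    # Hypothetical channel: params are leaves, name is static.
    def __init__(self, params, name):
        self.params = params
        self.name = name

    def tree_flatten(self):
        return (self.params,), self.name

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        (params,) = children
        return cls(params, aux_data)


@tree_util.register_pytree_node_class
class Module:
    # Hypothetical container: children may themselves be pytrees (the channels).
    def __init__(self, channels, radius):
        self.channels = channels
        self.radius = radius

    def tree_flatten(self):
        return (self.channels, self.radius), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        channels, radius = children
        return cls(channels, radius)


cell = Module(
    channels=[Channel({"gNa": jnp.array(0.12)}, "Na"), Channel({"gK": jnp.array(0.036)}, "K")],
    radius=jnp.array(1.0),
)

# Flattening recurses through the list and into each Channel.
leaves, treedef = tree_util.tree_flatten(cell)

# One tree_map rescales every leaf of the cell, channel params included.
doubled = tree_util.tree_map(lambda x: 2 * x, cell)
```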
