Support for MultiRNNCell #472
I will take a closer look at this, but I think we can just use `Chain`. If a `Chain` works, we should just add a note to the manual.
Chains of this type work:

```julia
Chain(
    Recurrence(RNNCell(inputsize => latentsize); return_sequence=true),
    x -> stack(x; dims=2),
    Recurrence(RNNCell(latentsize => latentsize); return_sequence=true),
    x -> stack(x; dims=2),
    # ⋮
)
```
I don't think you need the `stack` here. `Recurrence` should be able to take a VectorOfArray input (this was one of the reasons for not stacking the outputs by default).
Yes, I just tested that. I only need the `stack` at the end, so a chain like this works:

```julia
Chain(
    Recurrence(RNNCell(inputsize => latentsize); return_sequence=true),
    Recurrence(RNNCell(latentsize => latentsize); return_sequence=true),
    # ⋮
    x -> stack(x; dims=2),
)
```
So a note in the documentation is enough for this.
It would be nice to have an equivalent of `MultiRNNCell` in Lux.
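A minimal sketch of a `MultiRNNCell`-style stack built from the `Chain` pattern discussed in this thread. The helper name `stacked_rnn`, the layer sizes, and the random input are hypothetical choices for illustration, not part of Lux. It assumes that `Recurrence` accepts a 3-D `(features, time, batch)` array as well as the per-timestep collection that a previous `Recurrence(...; return_sequence=true)` returns, and it needs Julia 1.9 or newer for `stack`.

```julia
using Lux, Random

# Hypothetical helper (not a Lux API): one Recurrence layer per consecutive
# pair of sizes in `dims`, mimicking a stacked MultiRNNCell.
function stacked_rnn(dims::Vector{Int})
    layers = [Recurrence(RNNCell(dims[i] => dims[i + 1]); return_sequence=true)
              for i in 1:(length(dims) - 1)]
    # Each Recurrence returns its per-timestep outputs, which the next one
    # consumes directly; only the final output is stacked into an array.
    return Chain(layers..., x -> stack(x; dims=2))
end

rng = Xoshiro(0)
model = stacked_rnn([3, 8, 4])            # 3 inputs -> 8 hidden -> 4 hidden
ps, st = Lux.setup(rng, model)

x = randn(rng, Float32, 3, 5, 2)          # (features, time steps, batch)
y, _ = model(x, ps, st)
println(size(y))                          # (4, 5, 2)
```

Stacking only once at the end follows the observation above that intermediate `stack` calls are unnecessary between `Recurrence` layers.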