Map experimental C (actually C++) API for gradient tape #283
base: master
@saudet, is this PR still just a draft, or do you think it is ready to be reviewed and merged?
Since it doesn't look like we're going to do anything for this with TF 2.4.x, I think I'll upgrade this PR to 2.5.0-rc1, and then we can merge after its release? I don't think it makes sense to start doing something with the API for 2.4.x.
I've finally rebased this on master and upgraded for TF 2.5.0! I've also undone the unreadable reformatting of presets/tensorflow.java, but feel free to redo if necessary. I'd still consider this a WIP, but if it doesn't break any builds, it should be fine to merge and start getting people playing with it, as long as we're ready to maintain an unstable experimental API...
As far as I know we're waiting on full support in TensorFlow. The RFC is here and there seems to be work happening gradually, but I don't think it's at the point where we can use it as a full solution for gradients. The actual gradients are here, and as you can see there are quite a few missing, and there's currently no registration method or anything like that.
@dosier FWIW, you may be better off with PyTorch. The JavaCPP Presets for PyTorch provide full access to that C++ API.
Cheers! Been hoping to see the GradientTape integrated into the TF Java API for a while, mainly so that I can contribute RL stuff in KotlinDL :D. But I also need to implement an RL model for my studies this block, so the PyTorch Java wrapper is a pleasant surprise (I love my static types too much).
BTW, it looks like the author of KotlinDL would be open to integrating PyTorch as well, see pytorch/pytorch#58973 (comment). /cc @zaleslaw
Yeah, there are a few ways to integrate Torch in Kotlin.
I hope that in 2022 KotlinDL will be able to support the training of Torch models via JNI (or via JavaCPP). Good luck with your RL experiments, @dosier, and I hope to see you in the future with running RL models.
@zaleslaw May I ask why you're considering writing JNI code manually? What's missing from JavaCPP?
@rnett, are you saying that even with the custom gradient support you added not long ago plus this new internal API, we are still not able to register our own gradients in eager mode?
I'm saying that for this new API, there are no built-in registries (i.e. global, or Graph/EagerSession based), so we would have to create and manage our own. Once we do that, it would be easy enough to add custom Java-side gradients. I haven't seen confirmation anywhere that some sort of registry and auto-registration is planned, but I would expect it. I'm not sure how Python does it, or if it's using this setup at all.
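To make the "create and manage our own" registry idea concrete, here is a minimal sketch of what a Java-side registry could look like. Everything in it is hypothetical: `GradientRegistry`, `GradientFunction`, the op-type string key, and the use of plain doubles in place of tensors are assumptions for illustration only, and none of it is part of TF Java or the experimental C++ API mapped in this PR.

```java
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical sketch: these types do not exist in TF Java or in the mapped C++ API.
public class GradientRegistry {

    /** Hypothetical callback: maps an op's inputs and upstream gradients to input gradients. */
    @FunctionalInterface
    public interface GradientFunction {
        List<Double> apply(List<Double> inputs, List<Double> upstreamGrads);
    }

    // Keyed by op type name; thread-safe so one instance could be shared globally.
    private final Map<String, GradientFunction> byOpType = new ConcurrentHashMap<>();

    /** Registers a gradient for an op type, rejecting duplicate registrations. */
    public void register(String opType, GradientFunction fn) {
        if (byOpType.putIfAbsent(opType, fn) != null) {
            throw new IllegalStateException("Gradient already registered for " + opType);
        }
    }

    /** Looks up the gradient for an op type, if one has been registered. */
    public Optional<GradientFunction> lookup(String opType) {
        return Optional.ofNullable(byOpType.get(opType));
    }

    public static void main(String[] args) {
        GradientRegistry registry = new GradientRegistry();
        // d(x * x)/dx = 2x, scaled by the upstream gradient.
        registry.register("Square", (in, up) -> List.of(2.0 * in.get(0) * up.get(0)));

        System.out.println(registry.lookup("Square").get().apply(List.of(3.0), List.of(1.0))); // prints [6.0]
        System.out.println(registry.lookup("Missing").isPresent()); // prints false
    }
}
```

A single shared instance would play the role of a global registry, while creating one instance per Graph or EagerSession would give the session-scoped variant. Which of the two fits better depends on how the native tape ends up looking gradients up, which is still open.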
I've only tested the build on Linux for now, but it should also work on Mac and Windows.
Note that the API has already changed with 2.5.0, so we should probably upgrade to that version before looking at this too closely.