Graph custom gradient support #292
We probably need to "define" a wrapper class for the |
@rnett looks very cool, especially your test with registered grads for the Concat function. Did you take the grad formulas from the C++ code or from Python?
I don't have any specific plans, but I think it would be good to add missing grads in the Unfortunately writing gradients is rather hard atm since you don't have good ways to access attributes or inputs for Ops (like say
I'm not taking grad formulas from anywhere yet, this method adds the gradient to those in |
Yeah, I know, but IMO it's cleaner than just relying on different packages. There are a number of methods (GradientHelpers mostly) that need both.
Yeah, it's graph only; these are the legacy graph gradients. But until the graph backend starts using the gradient tape API, it's the only way to add gradients for graphs.
cc @saudet Hmm, ok, I still get errors with a vector adapter: (
Making the adapter type a pointer doesn't help either.
Adapters don't work for defining function types like that, whether it is |
Had to remove |
Sounds like memory corruption. Something is probably getting deallocated too early. We can set the "org.bytedeco.javacpp.nopointergc" system property to "true" and see that way whether or not it's GC doing that.
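For anyone trying the suggestion above, here is a minimal sketch of setting that property from code. The property name is the one quoted in the comment; the class and method names are made up for illustration. As far as I know JavaCPP reads the property when its `Pointer` class initializes, so this would need to run before any JavaCPP class is touched (passing `-Dorg.bytedeco.javacpp.nopointergc=true` on the command line avoids the ordering question).

```java
public final class NoPointerGcDebug {
    // Property name quoted from the comment above; everything else here is illustrative.
    public static final String NO_POINTER_GC = "org.bytedeco.javacpp.nopointergc";

    /** Asks JavaCPP not to deallocate native memory from GC. Debugging aid only. */
    public static void disablePointerGc() {
        System.setProperty(NO_POINTER_GC, "true");
    }

    public static void main(String[] args) {
        // Assumption: this runs before any JavaCPP class has been loaded.
        disablePointerGc();
        System.out.println(NO_POINTER_GC + "=" + System.getProperty(NO_POINTER_GC));
    }
}
```

If the crash disappears with the property set, premature GC-driven deallocation is the likely culprit; if it persists, the problem is probably elsewhere.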
Yeah, my function pointers were getting GC'd, I forgot to save them.
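A rough sketch of the "save them" fix, in plain Java so it runs without JavaCPP: keep every callback in a static map so it stays strongly reachable after being handed to native code. `GradCallback` is a hypothetical stand-in for a JavaCPP `FunctionPointer` subclass, and `CallbackRegistry` is not a class from this PR.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public final class CallbackRegistry {
    /** Hypothetical stand-in for a native function-pointer wrapper. */
    public interface GradCallback {
        String opType();
    }

    // Strong references: entries stay reachable for as long as this class is loaded,
    // so the garbage collector cannot reclaim a callback native code still points at.
    private static final Map<String, GradCallback> KEEP_ALIVE = new ConcurrentHashMap<>();

    /** Stores the callback and returns it, so it can be retained at the registration site. */
    public static GradCallback retain(GradCallback callback) {
        KEEP_ALIVE.put(callback.opType(), callback);
        return callback;
    }

    public static boolean isRetained(String opType) {
        return KEEP_ALIVE.containsKey(opType);
    }

    public static void main(String[] args) {
        GradCallback concat = retain(() -> "Concat");
        System.out.println(concat.opType() + " retained: " + isRetained("Concat"));
    }
}
```

Keying by op type means registering a second callback for the same op replaces the first, which then becomes collectable again; a list would be the safer choice if native code might still hold the old pointer.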
Next question: I'm generating a wrapper for |
Right, the way that works is by mapping a minimalist set of functions that are usually available in these kinds of templates. There should be some other way to erase an element though, isn't there? In any case, let me figure out some way to customize the output of that a bit...
Doesn't seem like it; if there is, it's not coming up on Google.
Ok, I'm confident enough that pretty much all "map" containers have an
BTW, it looks like you're starting to map all of the legacy C++ API. You could pick up from what has already been done for TF 1.x: |
Thanks, that works nicely.
I'm trying not to, but it's gotten pretty big, so that should help.
04aba60
to
59750dc
Now I'm getting a segfault from |
If those |
Yeah, that's what I thought it could be, but I didn't change anything around those methods, and the gradient methods all use PointerScopes. Also, the failing methods all verify the pointer is not null before calling. I cherry-picked the JavaCPP generation changes ( |
Another note: only |
The method looks like this:

    try (PointerScope scope = new PointerScope()) {
      TF_Output output = new TF_Output().oper(getUnsafeNativeHandle()).index(index);
      return TF_OperationOutputNumConsumers(output);
    }

If I breakpoint right before the |
Ok, I've found the cause. It only happens when I include |
You could try to add them to the header file with a patch in here:
https://github.com/saudet/tensorflow-java/blob/master/tensorflow-core/tensorflow-core-api/WORKSPACE
|
Let's ask upstream first before expanding the C API ad hoc. There may be a reason those functions aren't part of the C API. Did you check to see if libtensorflow exports those symbols?
It doesn't; some of them are static and others are in an anonymous namespace. The namespaced ones are just helpers that would be nice to have, while the static ones ignore the graph's lock (which is required to define gradient functions via the C API). I can make an issue in tensorflow, but I'm not sure what to ask for other than a full custom gradient C API, which wouldn't be worth doing for the old version. These are functions that shouldn't really be public; the C API just isn't made with custom gradients in mind.
b39bd2f
to
d2ffa40
Ok, things work now, but it's a bit hacky. I'm going to make a tensorflow issue asking for the necessary functions, but even if they approve exporting them, we may want to merge this w/ the patch instead of waiting for 2.6 or whenever they make it in. I need three functions:
|
Also, can you mark this with CI Build?
Signed-off-by: Ryan Nett <[email protected]>
… Java Signed-off-by: Ryan Nett <[email protected]>
Ok, I think the scopes should be named if possible, and the docs on |
Can I get someone to re-run the CI jobs? The cache needs to be populated.
I think so.
I'll document the rawtypes and then push the generation later today.
All right, merging this now, thanks again for that great contribution, @rnett!
@karllessard @Craigacp generation is pushed, we're good to go. Edit: Welp, I got ninja'd.
This PR aims to add support for custom gradients for graphs, using the legacy gradient setup. Eventually it will be replaced by the gradient API in #283, but we have no idea when that will happen.