Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Graph custom gradient support #292

Merged
merged 27 commits into from
Nov 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
ee96979
Add JavaCPP generation for gradient registry stuff
rnett Apr 18, 2021
c5717b0
Working gradients
rnett Apr 28, 2021
12d579c
Add missing requireHandles
rnett Apr 28, 2021
1f1ba2c
Rebase, use try-with-resources
rnett Apr 28, 2021
89690ab
Expose the necessary symbols on Windows
rnett Apr 29, 2021
dfaac8d
Nicely handle pre-existing gradients
rnett Sep 19, 2021
2f22b57
Small-ish review changes
rnett Oct 13, 2021
3acc001
Use annotation instead of field reflection to store op types
rnett Oct 13, 2021
5127413
Cleanup and more review changes
rnett Oct 13, 2021
293fecc
Remove empty init file
rnett Oct 29, 2021
2eb9342
Update annotation generator comments
rnett Oct 29, 2021
f1f6e8b
Update annotation names and comments, and registerCustomGradient javadoc
rnett Oct 31, 2021
8c40b8c
Add no-arg ctor to BaseGradientAdapter
rnett Nov 5, 2021
d08b38b
Add documentation about dangerousGradientBuilder
rnett Nov 5, 2021
72ed4f0
Add Javadoc for getUnsafeNativeHandle
rnett Nov 5, 2021
fd2609d
More dangerous gradient builder javadocs
rnett Nov 5, 2021
36a6e30
Add note about why gradientFuncs is required
rnett Nov 5, 2021
2ddbb6c
Store and allow getting native scope device when it has been set from…
rnett Nov 5, 2021
759a754
Rename withDevice's parameter
rnett Nov 5, 2021
ed1da29
Update scope for fix review comments
rnett Nov 5, 2021
4504a6f
Clarify the difference between CustomGradient and RawCustomGradient
rnett Nov 5, 2021
be4840c
Remove experiment
rnett Nov 5, 2021
9601138
Adjust GraphOperation#input to not require a graph lock
rnett Nov 8, 2021
ca5d343
Remove printing from CustomGradientTest
rnett Nov 8, 2021
517fd8d
Cleanup adapter exceptions, name gradient scopes
rnett Nov 8, 2021
f5dd0e5
Document use of rawtypes
rnett Nov 11, 2021
bfbfb49
Generate the new op classes
rnett Nov 11, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
2 changes: 2 additions & 0 deletions tensorflow-core/tensorflow-core-api/WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ http_archive(
# ":tensorflow-macosx.patch",
# ":tensorflow-windows.patch", # https://github.com/tensorflow/tensorflow/issues/25213
":tensorflow-proto.patch",
":custom-grad-helpers.patch",
":custom-grad-symbols.patch",
],
patch_tool = "patch",
patch_args = ["-p1"],
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index f3bf7b98a1e6b..c9194c36c116b 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -782,9 +782,9 @@ void TF_GraphGetTensorShape(TF_Graph* graph, TF_Output output, int64_t* dims,

extern "C" {

-static TF_OperationDescription* TF_NewOperationLocked(TF_Graph* graph,
- const char* op_type,
- const char* oper_name)
+TF_OperationDescription* TF_NewOperationLocked(TF_Graph* graph,
+ const char* op_type,
+ const char* oper_name)
TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
return new TF_OperationDescription(graph, op_type, oper_name);
}
@@ -1041,8 +1041,8 @@ void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name,
status->status = Status::OK();
}

-static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc,
- TF_Status* status)
+TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc,
+ TF_Status* status)
TF_EXCLUSIVE_LOCKS_REQUIRED(desc->graph->mu) {
Node* ret = nullptr;

diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h
index 705cf85e0512f..fb746dd4af94f 100644
--- a/tensorflow/c/c_api.h
+++ b/tensorflow/c/c_api.h
@@ -255,6 +255,12 @@ TF_CAPI_EXPORT extern void TF_GraphGetTensorShape(TF_Graph* graph,
int64_t* dims, int num_dims,
TF_Status* status);

+// TF_NewOperation, but without locking the graph.
+// Should prefer TF_NewOperation when possible.
+TF_CAPI_EXPORT extern TF_OperationDescription* TF_NewOperationLocked(TF_Graph* graph,
+ const char* op_type,
+ const char* oper_name);
+
// Operation will only be added to *graph when TF_FinishOperation() is
// called (assuming TF_FinishOperation() does not return an error).
// *graph must not be deleted until after TF_FinishOperation() is
@@ -406,6 +412,11 @@ TF_CAPI_EXPORT extern void TF_SetAttrValueProto(TF_OperationDescription* desc,
size_t proto_len,
TF_Status* status);

+// TF_FinishOperation, but without locking the graph.
+// TF_FinishOperation should be preferred when possible.
+TF_CAPI_EXPORT extern TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc,
+ TF_Status* status);
+
// If this function succeeds:
// * *status is set to an OK value,
// * a TF_Operation is added to the graph,
151 changes: 151 additions & 0 deletions tensorflow-core/tensorflow-core-api/external/custom-grad-symbols.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
Index: tensorflow/tools/def_file_filter/BUILD
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/tensorflow/tools/def_file_filter/BUILD b/tensorflow/tools/def_file_filter/BUILD
--- a/tensorflow/tools/def_file_filter/BUILD (revision 5e5cc35b4c0f629a1e092b540fdf2b63367aa5ad)
+++ b/tensorflow/tools/def_file_filter/BUILD (date 1629063191558)
@@ -12,3 +12,8 @@
name = "symbols_pybind",
srcs = ["symbols_pybind.txt"],
)
+
+filegroup(
+ name = "symbols_java",
+ srcs = ["symbols_java.txt"],
+)
Index: tensorflow/BUILD
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
--- a/tensorflow/BUILD (revision 5e5cc35b4c0f629a1e092b540fdf2b63367aa5ad)
+++ b/tensorflow/BUILD (date 1629063361078)
@@ -1069,13 +1069,20 @@
# the dynamic libraries of custom ops can find it at runtime.
genrule(
name = "tensorflow_filtered_def_file",
- srcs = [":tensorflow_def_file"],
+ srcs = [
+ ":tensorflow_def_file",
+ ":java_symbol_target_libs_file",
+ ":win_lib_files_for_java_exported_symbols",
+ "//tensorflow/tools/def_file_filter:symbols_java",
+ ],
outs = ["tensorflow_filtered_def_file.def"],
cmd = select({
"//tensorflow:windows": """
$(location @local_config_def_file_filter//:def_file_filter) \\
--input $(location :tensorflow_def_file) \\
- --output $@
+ --output $@ \\
+ --symbols $(location //tensorflow/tools/def_file_filter:symbols_java) \\
+ --lib_paths_file $(location :java_symbol_target_libs_file)
""",
"//conditions:default": "touch $@", # Just a placeholder for Unix platforms
}),
@@ -1083,6 +1090,34 @@
visibility = ["//visibility:public"],
)

+# Write to a file a list of all cc_library targets that we need for exporting symbols on Windows.
+genrule(
+ name = "java_symbol_target_libs_file",
+ srcs = [":win_lib_files_for_java_exported_symbols"],
+ outs = ["java_symbol_target_libs_file.txt"],
+ cmd = select({
+ "//tensorflow:windows": """
+ for SRC in $(SRCS); do
+ echo $$SRC | sed 's/third_party\\///g' >> $@
+ done
+ """,
+ "//conditions:default": "touch $@", # Just a placeholder for Unix platforms
+ }),
+ visibility = ["//visibility:public"],
+)
+
+filegroup(
+ name = "win_lib_files_for_java_exported_symbols",
+ srcs = [
+ "//tensorflow/cc:scope",
+ "//tensorflow/cc:grad_op_registry",
+ "//tensorflow/c:tf_status_helper",
+ "//tensorflow/cc:ops"
+ ],
+ visibility = ["//visibility:private"],
+)
+
+
# The interface library (tensorflow.dll.if.lib) for linking tensorflow DLL library (tensorflow.dll) on Windows.
# To learn more about import library (called interface library in Bazel):
# https://docs.microsoft.com/en-us/cpp/build/linking-an-executable-to-a-dll?view=vs-2017#linking-implicitly
Index: tensorflow/tools/def_file_filter/BUILD.tpl
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/tensorflow/tools/def_file_filter/BUILD.tpl b/tensorflow/tools/def_file_filter/BUILD.tpl
--- a/tensorflow/tools/def_file_filter/BUILD.tpl (revision 5e5cc35b4c0f629a1e092b540fdf2b63367aa5ad)
+++ b/tensorflow/tools/def_file_filter/BUILD.tpl (date 1629063191583)
@@ -18,3 +18,8 @@
name = "symbols_pybind",
srcs = ["symbols_pybind.txt"],
)
+
+filegroup(
+ name = "symbols_java",
+ srcs = ["symbols_java.txt"],
+)
Index: tensorflow/tools/def_file_filter/symbols_java.txt
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/tensorflow/tools/def_file_filter/symbols_java.txt b/tensorflow/tools/def_file_filter/symbols_java.txt
new file mode 100644
--- /dev/null (date 1629063607794)
+++ b/tensorflow/tools/def_file_filter/symbols_java.txt (date 1629063607794)
@@ -0,0 +1,26 @@
+[//tensorflow/cc:scope] # scope
+tensorflow::Scope::graph
+tensorflow::Scope::ok
+tensorflow::Scope::UpdateBuilder
+tensorflow::Scope::GetUniqueNameForOp
+tensorflow::Scope::ExitOnError
+tensorflow::Scope::WithDevice
+tensorflow::Scope::WithNoControlDependencies
+tensorflow::Scope::WithControlDependencies
+tensorflow::Scope::NewSubScope
+tensorflow::Scope::NewRootScope
+tensorflow::Scope::operator=
+tensorflow::Scope::~Scope
+tensorflow::Scope::Scope
+
+[//tensorflow/cc:ops]
+tensorflow::Operation::Operation
+
+[//tensorflow/cc:grad_op_registry] # custom gradients for graph
+tensorflow::ops::GradOpRegistry::Global
+tensorflow::ops::GradOpRegistry::Lookup
+tensorflow::ops::GradOpRegistry::Register
+
+[//tensorflow/c:tf_status_helper] # status helpers
+tensorflow::Set_TF_Status_from_Status
+tensorflow::StatusFromTF_Status
===================================================================
diff --git a/tensorflow/tools/def_file_filter/def_file_filter.py.tpl b/tensorflow/tools/def_file_filter/def_file_filter.py.tpl
--- a/tensorflow/tools/def_file_filter/def_file_filter.py.tpl (revision 919f693420e35d00c8d0a42100837ae3718f7927)
+++ b/tensorflow/tools/def_file_filter/def_file_filter.py.tpl (date 1632048268359)
@@ -143,8 +143,8 @@
re_filter_comp = re.compile(r"{}".format(re_filter))

# Filter out symbol from the split line (`sym_split` in the for loop below).
- sym_line_filter = r".*\s+\| (.*) \(.*"
- sym_line_filter_anomaly = r".*\s+\| (.*)"
+ sym_line_filter = r".*\s+\| (.*?) \(.*"
+ sym_line_filter_anomaly = r".*\s+\| (.*?)"

for sym_line in sym_split:
if re_filter_comp.search(sym_line):
18 changes: 18 additions & 0 deletions tensorflow-core/tensorflow-core-api/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,19 @@
</execution>
</executions>
</plugin>
<plugin>
<artifactId>maven-resources-plugin</artifactId>
<version>3.1.0</version>
<executions>
<execution>
<id>javacpp-parser</id>
<phase>generate-sources</phase>
<goals>
<goal>resources</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.8.0</version>
Expand Down Expand Up @@ -212,6 +225,11 @@
<includePaths>
<includePath>${project.basedir}/</includePath>
<includePath>${project.basedir}/bazel-${project.artifactId}/external/org_tensorflow/</includePath>
<includePath>${project.basedir}/bazel-bin/external/org_tensorflow/</includePath>
Craigacp marked this conversation as resolved.
Show resolved Hide resolved
<includePath>${project.basedir}/bazel-${project.artifactId}/external/com_google_absl/</includePath>
<includePath>${project.basedir}/bazel-${project.artifactId}/external/eigen_archive/</includePath>
<includePath>${project.basedir}/bazel-${project.artifactId}/external/com_google_protobuf/src/</includePath>
<includePath>${project.basedir}/target/classes/org/tensorflow/internal/c_api/include/</includePath>
</includePaths>
<linkPaths>
<linkPath>${project.basedir}/bazel-bin/external/llvm_openmp/</linkPath>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -367,10 +367,10 @@ public final class Ops {

public final SparseOps sparse;

public final BitwiseOps bitwise;

public final TpuOps tpu;

public final BitwiseOps bitwise;

public final MathOps math;

public final AudioOps audio;
Expand All @@ -383,7 +383,7 @@ public final class Ops {

private final Scope scope;

private Ops(Scope scope) {
Ops(Scope scope) {
this.scope = scope;
nn = new NnOps(this);
summary = new SummaryOps(this);
Expand All @@ -398,8 +398,8 @@ private Ops(Scope scope) {
random = new RandomOps(this);
strings = new StringsOps(this);
sparse = new SparseOps(this);
bitwise = new BitwiseOps(this);
tpu = new TpuOps(this);
bitwise = new BitwiseOps(this);
math = new MathOps(this);
audio = new AudioOps(this);
signal = new SignalOps(this);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Targeted by JavaCPP version 1.5.6: DO NOT EDIT THIS FILE

package org.tensorflow.internal.c_api;

import java.nio.*;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.*;

import static org.tensorflow.internal.c_api.global.tensorflow.*;


/** GradFunc is the signature for all gradient functions in GradOpRegistry.
* Implementations should add operations to compute the gradient outputs of
* 'op' (returned in 'grad_outputs') using 'scope' and 'grad_inputs'. */
@Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class)
public class GradFunc extends FunctionPointer {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public GradFunc(Pointer p) { super(p); }
protected GradFunc() { allocate(); }
private native void allocate();
public native @ByVal NativeStatus call(@Const @ByRef TF_Scope scope, @Const @ByRef NativeOperation op,
@Const @ByRef NativeOutputVector grad_inputs,
NativeOutputVector grad_outputs);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Targeted by JavaCPP version 1.5.6: DO NOT EDIT THIS FILE

package org.tensorflow.internal.c_api;

import java.nio.*;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.*;

import static org.tensorflow.internal.c_api.global.tensorflow.*;


/** GradOpRegistry maintains a static registry of gradient functions.
* Gradient functions are indexed in the registry by the forward op name (i.e.
* "MatMul" -> MatMulGrad func). */
@Namespace("tensorflow::ops") @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class)
public class GradOpRegistry extends Pointer {
static { Loader.load(); }
/** Default native constructor. */
public GradOpRegistry() { super((Pointer)null); allocate(); }
Craigacp marked this conversation as resolved.
Show resolved Hide resolved
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public GradOpRegistry(long size) { super((Pointer)null); allocateArray(size); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public GradOpRegistry(Pointer p) { super(p); }
private native void allocate();
private native void allocateArray(long size);
@Override public GradOpRegistry position(long position) {
return (GradOpRegistry)super.position(position);
}
@Override public GradOpRegistry getPointer(long i) {
return new GradOpRegistry((Pointer)this).offsetAddress(i);
}

/** Registers 'func' as the gradient function for 'op'.
* Returns true if registration was successful, check fails otherwise. */
public native @Cast("bool") boolean Register(@StdString BytePointer op, GradFunc func);
public native @Cast("bool") boolean Register(@StdString String op, GradFunc func);

/** Sets 'func' to the gradient function for 'op' and returns Status OK if
* the gradient function for 'op' exists in the registry.
* Note that 'func' can be null for ops that have registered no-gradient with
* the registry.
* Returns error status otherwise. */
public native @ByVal NativeStatus Lookup(@StdString BytePointer op, @ByPtrPtr GradFunc func);
public native @ByVal NativeStatus Lookup(@StdString String op, @ByPtrPtr GradFunc func);

/** Returns a pointer to the global gradient function registry. */
public static native GradOpRegistry Global();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Targeted by JavaCPP version 1.5.6: DO NOT EDIT THIS FILE

package org.tensorflow.internal.c_api;

import java.nio.*;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.*;

import static org.tensorflow.internal.c_api.global.tensorflow.*;

@Name("std::unordered_map<tensorflow::string,tensorflow::Node*>") @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class)
public class NameMap extends Pointer {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public NameMap(Pointer p) { super(p); }
public NameMap() { allocate(); }
Craigacp marked this conversation as resolved.
Show resolved Hide resolved
private native void allocate();
public native @Name("operator =") @ByRef NameMap put(@ByRef NameMap x);

public boolean empty() { return size() == 0; }
public native long size();

@Index public native Node get(@StdString BytePointer i);
public native NameMap put(@StdString BytePointer i, Node value);

public native void erase(@ByVal Iterator pos);
public native @ByVal Iterator begin();
public native @ByVal Iterator end();
@NoOffset @Name("iterator") public static class Iterator extends Pointer {
public Iterator(Pointer p) { super(p); }
public Iterator() { }

public native @Name("operator ++") @ByRef Iterator increment();
public native @Name("operator ==") boolean equals(@ByRef Iterator it);
public native @Name("operator *().first") @MemberGetter @StdString BytePointer first();
public native @Name("operator *().second") @MemberGetter @Const Node second();
}

public native long erase(@StdString BytePointer key);
}

Loading