Skip to content

Commit

Permalink
Graph custom gradient support (#292)
Browse files Browse the repository at this point in the history
  • Loading branch information
rnett authored Nov 11, 2021
1 parent 55547dd commit e0eec4a
Show file tree
Hide file tree
Showing 1,367 changed files with 17,408 additions and 2,986 deletions.
2 changes: 2 additions & 0 deletions tensorflow-core/tensorflow-core-api/WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ http_archive(
# ":tensorflow-macosx.patch",
# ":tensorflow-windows.patch", # https://github.com/tensorflow/tensorflow/issues/25213
":tensorflow-proto.patch",
":custom-grad-helpers.patch",
":custom-grad-symbols.patch",
],
patch_tool = "patch",
patch_args = ["-p1"],
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index f3bf7b98a1e6b..c9194c36c116b 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -782,9 +782,9 @@ void TF_GraphGetTensorShape(TF_Graph* graph, TF_Output output, int64_t* dims,

extern "C" {

-static TF_OperationDescription* TF_NewOperationLocked(TF_Graph* graph,
- const char* op_type,
- const char* oper_name)
+TF_OperationDescription* TF_NewOperationLocked(TF_Graph* graph,
+ const char* op_type,
+ const char* oper_name)
TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
return new TF_OperationDescription(graph, op_type, oper_name);
}
@@ -1041,8 +1041,8 @@ void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name,
status->status = Status::OK();
}

-static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc,
- TF_Status* status)
+TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc,
+ TF_Status* status)
TF_EXCLUSIVE_LOCKS_REQUIRED(desc->graph->mu) {
Node* ret = nullptr;

diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h
index 705cf85e0512f..fb746dd4af94f 100644
--- a/tensorflow/c/c_api.h
+++ b/tensorflow/c/c_api.h
@@ -255,6 +255,12 @@ TF_CAPI_EXPORT extern void TF_GraphGetTensorShape(TF_Graph* graph,
int64_t* dims, int num_dims,
TF_Status* status);

+// Like TF_NewOperation, but does not acquire graph->mu; the caller must already hold it.
+// Prefer TF_NewOperation when possible.
+TF_CAPI_EXPORT extern TF_OperationDescription* TF_NewOperationLocked(TF_Graph* graph,
+ const char* op_type,
+ const char* oper_name);
+
// Operation will only be added to *graph when TF_FinishOperation() is
// called (assuming TF_FinishOperation() does not return an error).
// *graph must not be deleted until after TF_FinishOperation() is
@@ -406,6 +412,11 @@ TF_CAPI_EXPORT extern void TF_SetAttrValueProto(TF_OperationDescription* desc,
size_t proto_len,
TF_Status* status);

+// Like TF_FinishOperation, but does not acquire desc->graph->mu; the caller must already hold it.
+// Prefer TF_FinishOperation when possible.
+TF_CAPI_EXPORT extern TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc,
+ TF_Status* status);
+
// If this function succeeds:
// * *status is set to an OK value,
// * a TF_Operation is added to the graph,
151 changes: 151 additions & 0 deletions tensorflow-core/tensorflow-core-api/external/custom-grad-symbols.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
Index: tensorflow/tools/def_file_filter/BUILD
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/tensorflow/tools/def_file_filter/BUILD b/tensorflow/tools/def_file_filter/BUILD
--- a/tensorflow/tools/def_file_filter/BUILD (revision 5e5cc35b4c0f629a1e092b540fdf2b63367aa5ad)
+++ b/tensorflow/tools/def_file_filter/BUILD (date 1629063191558)
@@ -12,3 +12,8 @@
name = "symbols_pybind",
srcs = ["symbols_pybind.txt"],
)
+
+filegroup(
+ name = "symbols_java",
+ srcs = ["symbols_java.txt"],
+)
Index: tensorflow/BUILD
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
--- a/tensorflow/BUILD (revision 5e5cc35b4c0f629a1e092b540fdf2b63367aa5ad)
+++ b/tensorflow/BUILD (date 1629063361078)
@@ -1069,13 +1069,20 @@
# the dynamic libraries of custom ops can find it at runtime.
genrule(
name = "tensorflow_filtered_def_file",
- srcs = [":tensorflow_def_file"],
+ srcs = [
+ ":tensorflow_def_file",
+ ":java_symbol_target_libs_file",
+ ":win_lib_files_for_java_exported_symbols",
+ "//tensorflow/tools/def_file_filter:symbols_java",
+ ],
outs = ["tensorflow_filtered_def_file.def"],
cmd = select({
"//tensorflow:windows": """
$(location @local_config_def_file_filter//:def_file_filter) \\
--input $(location :tensorflow_def_file) \\
- --output $@
+ --output $@ \\
+ --symbols $(location //tensorflow/tools/def_file_filter:symbols_java) \\
+ --lib_paths_file $(location :java_symbol_target_libs_file)
""",
"//conditions:default": "touch $@", # Just a placeholder for Unix platforms
}),
@@ -1083,6 +1090,34 @@
visibility = ["//visibility:public"],
)

+# Write to a file a list of all cc_library targets that we need for exporting symbols on Windows.
+genrule(
+ name = "java_symbol_target_libs_file",
+ srcs = [":win_lib_files_for_java_exported_symbols"],
+ outs = ["java_symbol_target_libs_file.txt"],
+ cmd = select({
+ "//tensorflow:windows": """
+ for SRC in $(SRCS); do
+ echo $$SRC | sed 's/third_party\\///g' >> $@
+ done
+ """,
+ "//conditions:default": "touch $@", # Just a placeholder for Unix platforms
+ }),
+ visibility = ["//visibility:public"],
+)
+
+filegroup(
+ name = "win_lib_files_for_java_exported_symbols",
+ srcs = [
+ "//tensorflow/cc:scope",
+ "//tensorflow/cc:grad_op_registry",
+ "//tensorflow/c:tf_status_helper",
+ "//tensorflow/cc:ops"
+ ],
+ visibility = ["//visibility:private"],
+)
+
+
# The interface library (tensorflow.dll.if.lib) for linking tensorflow DLL library (tensorflow.dll) on Windows.
# To learn more about import library (called interface library in Bazel):
# https://docs.microsoft.com/en-us/cpp/build/linking-an-executable-to-a-dll?view=vs-2017#linking-implicitly
Index: tensorflow/tools/def_file_filter/BUILD.tpl
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/tensorflow/tools/def_file_filter/BUILD.tpl b/tensorflow/tools/def_file_filter/BUILD.tpl
--- a/tensorflow/tools/def_file_filter/BUILD.tpl (revision 5e5cc35b4c0f629a1e092b540fdf2b63367aa5ad)
+++ b/tensorflow/tools/def_file_filter/BUILD.tpl (date 1629063191583)
@@ -18,3 +18,8 @@
name = "symbols_pybind",
srcs = ["symbols_pybind.txt"],
)
+
+filegroup(
+ name = "symbols_java",
+ srcs = ["symbols_java.txt"],
+)
Index: tensorflow/tools/def_file_filter/symbols_java.txt
IDEA additional info:
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP
<+>UTF-8
===================================================================
diff --git a/tensorflow/tools/def_file_filter/symbols_java.txt b/tensorflow/tools/def_file_filter/symbols_java.txt
new file mode 100644
--- /dev/null (date 1629063607794)
+++ b/tensorflow/tools/def_file_filter/symbols_java.txt (date 1629063607794)
@@ -0,0 +1,26 @@
+[//tensorflow/cc:scope] # scope
+tensorflow::Scope::graph
+tensorflow::Scope::ok
+tensorflow::Scope::UpdateBuilder
+tensorflow::Scope::GetUniqueNameForOp
+tensorflow::Scope::ExitOnError
+tensorflow::Scope::WithDevice
+tensorflow::Scope::WithNoControlDependencies
+tensorflow::Scope::WithControlDependencies
+tensorflow::Scope::NewSubScope
+tensorflow::Scope::NewRootScope
+tensorflow::Scope::operator=
+tensorflow::Scope::~Scope
+tensorflow::Scope::Scope
+
+[//tensorflow/cc:ops]
+tensorflow::Operation::Operation
+
+[//tensorflow/cc:grad_op_registry] # custom gradients for graph
+tensorflow::ops::GradOpRegistry::Global
+tensorflow::ops::GradOpRegistry::Lookup
+tensorflow::ops::GradOpRegistry::Register
+
+[//tensorflow/c:tf_status_helper] # status helpers
+tensorflow::Set_TF_Status_from_Status
+tensorflow::StatusFromTF_Status
===================================================================
diff --git a/tensorflow/tools/def_file_filter/def_file_filter.py.tpl b/tensorflow/tools/def_file_filter/def_file_filter.py.tpl
--- a/tensorflow/tools/def_file_filter/def_file_filter.py.tpl (revision 919f693420e35d00c8d0a42100837ae3718f7927)
+++ b/tensorflow/tools/def_file_filter/def_file_filter.py.tpl (date 1632048268359)
@@ -143,8 +143,8 @@
re_filter_comp = re.compile(r"{}".format(re_filter))

# Filter out symbol from the split line (`sym_split` in the for loop below).
- sym_line_filter = r".*\s+\| (.*) \(.*"
- sym_line_filter_anomaly = r".*\s+\| (.*)"
+ sym_line_filter = r".*\s+\| (.*?) \(.*"
+ sym_line_filter_anomaly = r".*\s+\| (.*?)"

for sym_line in sym_split:
if re_filter_comp.search(sym_line):
18 changes: 18 additions & 0 deletions tensorflow-core/tensorflow-core-api/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,19 @@
</execution>
</executions>
</plugin>
<!-- Binds the maven-resources-plugin "resources" goal to the generate-sources phase
     under the execution id "javacpp-parser".
     NOTE(review): presumably this makes project resources available under target/classes
     before the JavaCPP parser runs; confirm against the parser's include paths. -->
<plugin>
<artifactId>maven-resources-plugin</artifactId>
<version>3.1.0</version>
<executions>
<execution>
<id>javacpp-parser</id>
<phase>generate-sources</phase>
<goals>
<goal>resources</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.8.0</version>
Expand Down Expand Up @@ -212,6 +225,11 @@
<includePaths>
<includePath>${project.basedir}/</includePath>
<includePath>${project.basedir}/bazel-${project.artifactId}/external/org_tensorflow/</includePath>
<includePath>${project.basedir}/bazel-bin/external/org_tensorflow/</includePath>
<includePath>${project.basedir}/bazel-${project.artifactId}/external/com_google_absl/</includePath>
<includePath>${project.basedir}/bazel-${project.artifactId}/external/eigen_archive/</includePath>
<includePath>${project.basedir}/bazel-${project.artifactId}/external/com_google_protobuf/src/</includePath>
<includePath>${project.basedir}/target/classes/org/tensorflow/internal/c_api/include/</includePath>
</includePaths>
<linkPaths>
<linkPath>${project.basedir}/bazel-bin/external/llvm_openmp/</linkPath>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -367,10 +367,10 @@ public final class Ops {

public final SparseOps sparse;

public final BitwiseOps bitwise;

public final TpuOps tpu;

public final BitwiseOps bitwise;

public final MathOps math;

public final AudioOps audio;
Expand All @@ -383,7 +383,7 @@ public final class Ops {

private final Scope scope;

private Ops(Scope scope) {
Ops(Scope scope) {
this.scope = scope;
nn = new NnOps(this);
summary = new SummaryOps(this);
Expand All @@ -398,8 +398,8 @@ private Ops(Scope scope) {
random = new RandomOps(this);
strings = new StringsOps(this);
sparse = new SparseOps(this);
bitwise = new BitwiseOps(this);
tpu = new TpuOps(this);
bitwise = new BitwiseOps(this);
math = new MathOps(this);
audio = new AudioOps(this);
signal = new SignalOps(this);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
// Targeted by JavaCPP version 1.5.6: DO NOT EDIT THIS FILE

package org.tensorflow.internal.c_api;

import java.nio.*;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.*;

import static org.tensorflow.internal.c_api.global.tensorflow.*;


/** GradFunc is the signature for all gradient functions in GradOpRegistry.
* Implementations should add operations to compute the gradient outputs of
* 'op' (returned in 'grad_outputs') using 'scope' and 'grad_inputs'. */
@Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class)
public class GradFunc extends FunctionPointer {
// Calls JavaCPP's Loader.load() once, when this class is initialized.
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public GradFunc(Pointer p) { super(p); }
/** Allocating constructor, visible to subclasses only; delegates to the native {@code allocate()}. */
protected GradFunc() { allocate(); }
/** Native allocation hook invoked by the no-argument constructor. */
private native void allocate();
/**
* Invokes the gradient function.
*
* @param scope scope used to build the gradient operations (passed by const reference)
* @param op the operation whose gradient outputs are computed (passed by const reference)
* @param grad_inputs the gradient inputs (passed by const reference)
* @param grad_outputs vector in which the computed gradient outputs are returned
* @return the resulting status, returned by value
*/
public native @ByVal NativeStatus call(@Const @ByRef TF_Scope scope, @Const @ByRef NativeOperation op,
@Const @ByRef NativeOutputVector grad_inputs,
NativeOutputVector grad_outputs);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Targeted by JavaCPP version 1.5.6: DO NOT EDIT THIS FILE

package org.tensorflow.internal.c_api;

import java.nio.*;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.*;

import static org.tensorflow.internal.c_api.global.tensorflow.*;


/** GradOpRegistry maintains a static registry of gradient functions.
* Gradient functions are indexed in the registry by the forward op name (i.e.
* "MatMul" -> MatMulGrad func). */
@Namespace("tensorflow::ops") @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class)
public class GradOpRegistry extends Pointer {
// Calls JavaCPP's Loader.load() once, when this class is initialized.
static { Loader.load(); }
/** Default native constructor. */
public GradOpRegistry() { super((Pointer)null); allocate(); }
/** Native array allocator. Access with {@link Pointer#position(long)}. */
public GradOpRegistry(long size) { super((Pointer)null); allocateArray(size); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public GradOpRegistry(Pointer p) { super(p); }
/** Native allocation hook invoked by the default constructor. */
private native void allocate();
/** Native array allocation hook invoked by the {@code long size} constructor. */
private native void allocateArray(long size);
/** Covariant override: delegates to {@code super.position(position)} and casts the result to GradOpRegistry. */
@Override public GradOpRegistry position(long position) {
return (GradOpRegistry)super.position(position);
}
/** Creates a new GradOpRegistry from this pointer and applies {@code offsetAddress(i)} to it. */
@Override public GradOpRegistry getPointer(long i) {
return new GradOpRegistry((Pointer)this).offsetAddress(i);
}

/** Registers 'func' as the gradient function for 'op'.
* Returns true if registration was successful, check fails otherwise. */
public native @Cast("bool") boolean Register(@StdString BytePointer op, GradFunc func);
/** Same as {@link #Register(BytePointer, GradFunc)}, taking the op name as a Java {@code String}. */
public native @Cast("bool") boolean Register(@StdString String op, GradFunc func);

/** Sets 'func' to the gradient function for 'op' and returns Status OK if
* the gradient function for 'op' exists in the registry.
* Note that 'func' can be null for ops that have registered no-gradient with
* the registry.
* Returns error status otherwise. */
public native @ByVal NativeStatus Lookup(@StdString BytePointer op, @ByPtrPtr GradFunc func);
/** Same as {@link #Lookup(BytePointer, GradFunc)}, taking the op name as a Java {@code String}. */
public native @ByVal NativeStatus Lookup(@StdString String op, @ByPtrPtr GradFunc func);

/** Returns a pointer to the global gradient function registry. */
public static native GradOpRegistry Global();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Targeted by JavaCPP version 1.5.6: DO NOT EDIT THIS FILE

package org.tensorflow.internal.c_api;

import java.nio.*;
import org.bytedeco.javacpp.*;
import org.bytedeco.javacpp.annotation.*;

import static org.tensorflow.internal.c_api.global.tensorflow.*;

/** JavaCPP binding for the native type
* {@code std::unordered_map<tensorflow::string,tensorflow::Node*>}:
* a map from string keys to {@link Node} pointers. */
@Name("std::unordered_map<tensorflow::string,tensorflow::Node*>") @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class)
public class NameMap extends Pointer {
// Calls JavaCPP's Loader.load() once, when this class is initialized.
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public NameMap(Pointer p) { super(p); }
/** Default constructor; delegates to the native {@code allocate()}. */
public NameMap() { allocate(); }
/** Native allocation hook invoked by the default constructor. */
private native void allocate();
/** Bound to the native {@code operator =}: assigns {@code x} to this map and returns a NameMap by reference. */
public native @Name("operator =") @ByRef NameMap put(@ByRef NameMap x);

/** Returns true when {@link #size()} is 0. */
public boolean empty() { return size() == 0; }
/** Returns the number of entries, as reported by the native map. */
public native long size();

/** Indexed getter: returns the {@link Node} for key {@code i}.
* NOTE(review): behavior for a missing key is not visible here; confirm against the native map semantics. */
@Index public native Node get(@StdString BytePointer i);
/** Stores {@code value} under key {@code i}.
* NOTE(review): presumably returns this map for chaining; confirm. */
public native NameMap put(@StdString BytePointer i, Node value);

/** Erases the entry at iterator position {@code pos}. */
public native void erase(@ByVal Iterator pos);
/** Returns an iterator, by value, from the native {@code begin()}. */
public native @ByVal Iterator begin();
/** Returns an iterator, by value, from the native {@code end()}. */
public native @ByVal Iterator end();
/** Binding for the native map's {@code iterator} type. */
@NoOffset @Name("iterator") public static class Iterator extends Pointer {
/** Pointer cast constructor. */
public Iterator(Pointer p) { super(p); }
/** Creates an Iterator without performing any native allocation in this constructor. */
public Iterator() { }

/** Bound to the native {@code operator ++}; returns an Iterator by reference. */
public native @Name("operator ++") @ByRef Iterator increment();
/** Bound to the native {@code operator ==}.
* Note: the parameter type is Iterator, so this overloads rather than overrides {@link Object#equals(Object)}. */
public native @Name("operator ==") boolean equals(@ByRef Iterator it);
/** Returns the {@code first} member of the dereferenced entry (the string key). */
public native @Name("operator *().first") @MemberGetter @StdString BytePointer first();
/** Returns the {@code second} member of the dereferenced entry (the {@link Node} value). */
public native @Name("operator *().second") @MemberGetter @Const Node second();
}

/** Erases by {@code key} and returns a {@code long}.
* NOTE(review): presumably the number of entries removed; confirm. */
public native long erase(@StdString BytePointer key);
}

Loading

0 comments on commit e0eec4a

Please sign in to comment.