-
Notifications
You must be signed in to change notification settings - Fork 206
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Graph custom gradient support (#292)
- Loading branch information
Showing
1,367 changed files
with
17,408 additions
and
2,986 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
57 changes: 57 additions & 0 deletions
57
tensorflow-core/tensorflow-core-api/external/custom-grad-helpers.patch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc | ||
index f3bf7b98a1e6b..c9194c36c116b 100644 | ||
--- a/tensorflow/c/c_api.cc | ||
+++ b/tensorflow/c/c_api.cc | ||
@@ -782,9 +782,9 @@ void TF_GraphGetTensorShape(TF_Graph* graph, TF_Output output, int64_t* dims, | ||
|
||
extern "C" { | ||
|
||
-static TF_OperationDescription* TF_NewOperationLocked(TF_Graph* graph, | ||
- const char* op_type, | ||
- const char* oper_name) | ||
+TF_OperationDescription* TF_NewOperationLocked(TF_Graph* graph, | ||
+ const char* op_type, | ||
+ const char* oper_name) | ||
TF_EXCLUSIVE_LOCKS_REQUIRED(graph->mu) { | ||
return new TF_OperationDescription(graph, op_type, oper_name); | ||
} | ||
@@ -1041,8 +1041,8 @@ void TF_SetAttrValueProto(TF_OperationDescription* desc, const char* attr_name, | ||
status->status = Status::OK(); | ||
} | ||
|
||
-static TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc, | ||
- TF_Status* status) | ||
+TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc, | ||
+ TF_Status* status) | ||
TF_EXCLUSIVE_LOCKS_REQUIRED(desc->graph->mu) { | ||
Node* ret = nullptr; | ||
|
||
diff --git a/tensorflow/c/c_api.h b/tensorflow/c/c_api.h | ||
index 705cf85e0512f..fb746dd4af94f 100644 | ||
--- a/tensorflow/c/c_api.h | ||
+++ b/tensorflow/c/c_api.h | ||
@@ -255,6 +255,12 @@ TF_CAPI_EXPORT extern void TF_GraphGetTensorShape(TF_Graph* graph, | ||
int64_t* dims, int num_dims, | ||
TF_Status* status); | ||
|
||
+// TF_NewOperation, but without locking the graph. | ||
+// Should prefer TF_NewOperation when possible. | ||
+TF_CAPI_EXPORT extern TF_OperationDescription* TF_NewOperationLocked(TF_Graph* graph, | ||
+ const char* op_type, | ||
+ const char* oper_name); | ||
+ | ||
// Operation will only be added to *graph when TF_FinishOperation() is | ||
// called (assuming TF_FinishOperation() does not return an error). | ||
// *graph must not be deleted until after TF_FinishOperation() is | ||
@@ -406,6 +412,11 @@ TF_CAPI_EXPORT extern void TF_SetAttrValueProto(TF_OperationDescription* desc, | ||
size_t proto_len, | ||
TF_Status* status); | ||
|
||
+// TF_FinishOperation, but without locking the graph. | ||
+// TF_FinishOperation should be preferred when possible. | ||
+TF_CAPI_EXPORT extern TF_Operation* TF_FinishOperationLocked(TF_OperationDescription* desc, | ||
+ TF_Status* status); | ||
+ | ||
// If this function succeeds: | ||
// * *status is set to an OK value, | ||
// * a TF_Operation is added to the graph, |
151 changes: 151 additions & 0 deletions
151
tensorflow-core/tensorflow-core-api/external/custom-grad-symbols.patch
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
Index: tensorflow/tools/def_file_filter/BUILD | ||
IDEA additional info: | ||
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP | ||
<+>UTF-8 | ||
=================================================================== | ||
diff --git a/tensorflow/tools/def_file_filter/BUILD b/tensorflow/tools/def_file_filter/BUILD | ||
--- a/tensorflow/tools/def_file_filter/BUILD (revision 5e5cc35b4c0f629a1e092b540fdf2b63367aa5ad) | ||
+++ b/tensorflow/tools/def_file_filter/BUILD (date 1629063191558) | ||
@@ -12,3 +12,8 @@ | ||
name = "symbols_pybind", | ||
srcs = ["symbols_pybind.txt"], | ||
) | ||
+ | ||
+filegroup( | ||
+ name = "symbols_java", | ||
+ srcs = ["symbols_java.txt"], | ||
+) | ||
Index: tensorflow/BUILD | ||
IDEA additional info: | ||
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP | ||
<+>UTF-8 | ||
=================================================================== | ||
diff --git a/tensorflow/BUILD b/tensorflow/BUILD | ||
--- a/tensorflow/BUILD (revision 5e5cc35b4c0f629a1e092b540fdf2b63367aa5ad) | ||
+++ b/tensorflow/BUILD (date 1629063361078) | ||
@@ -1069,13 +1069,20 @@ | ||
# the dynamic libraries of custom ops can find it at runtime. | ||
genrule( | ||
name = "tensorflow_filtered_def_file", | ||
- srcs = [":tensorflow_def_file"], | ||
+ srcs = [ | ||
+ ":tensorflow_def_file", | ||
+ ":java_symbol_target_libs_file", | ||
+ ":win_lib_files_for_java_exported_symbols", | ||
+ "//tensorflow/tools/def_file_filter:symbols_java", | ||
+ ], | ||
outs = ["tensorflow_filtered_def_file.def"], | ||
cmd = select({ | ||
"//tensorflow:windows": """ | ||
$(location @local_config_def_file_filter//:def_file_filter) \\ | ||
--input $(location :tensorflow_def_file) \\ | ||
- --output $@ | ||
+ --output $@ \\ | ||
+ --symbols $(location //tensorflow/tools/def_file_filter:symbols_java) \\ | ||
+ --lib_paths_file $(location :java_symbol_target_libs_file) | ||
""", | ||
"//conditions:default": "touch $@", # Just a placeholder for Unix platforms | ||
}), | ||
@@ -1083,6 +1090,34 @@ | ||
visibility = ["//visibility:public"], | ||
) | ||
|
||
+# Write to a file a list of all cc_library targets that we need for exporting symbols on Windows. | ||
+genrule( | ||
+ name = "java_symbol_target_libs_file", | ||
+ srcs = [":win_lib_files_for_java_exported_symbols"], | ||
+ outs = ["java_symbol_target_libs_file.txt"], | ||
+ cmd = select({ | ||
+ "//tensorflow:windows": """ | ||
+ for SRC in $(SRCS); do | ||
+ echo $$SRC | sed 's/third_party\\///g' >> $@ | ||
+ done | ||
+ """, | ||
+ "//conditions:default": "touch $@", # Just a placeholder for Unix platforms | ||
+ }), | ||
+ visibility = ["//visibility:public"], | ||
+) | ||
+ | ||
+filegroup( | ||
+ name = "win_lib_files_for_java_exported_symbols", | ||
+ srcs = [ | ||
+ "//tensorflow/cc:scope", | ||
+ "//tensorflow/cc:grad_op_registry", | ||
+ "//tensorflow/c:tf_status_helper", | ||
+ "//tensorflow/cc:ops" | ||
+ ], | ||
+ visibility = ["//visibility:private"], | ||
+) | ||
+ | ||
+ | ||
# The interface library (tensorflow.dll.if.lib) for linking tensorflow DLL library (tensorflow.dll) on Windows. | ||
# To learn more about import library (called interface library in Bazel): | ||
# https://docs.microsoft.com/en-us/cpp/build/linking-an-executable-to-a-dll?view=vs-2017#linking-implicitly | ||
Index: tensorflow/tools/def_file_filter/BUILD.tpl | ||
IDEA additional info: | ||
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP | ||
<+>UTF-8 | ||
=================================================================== | ||
diff --git a/tensorflow/tools/def_file_filter/BUILD.tpl b/tensorflow/tools/def_file_filter/BUILD.tpl | ||
--- a/tensorflow/tools/def_file_filter/BUILD.tpl (revision 5e5cc35b4c0f629a1e092b540fdf2b63367aa5ad) | ||
+++ b/tensorflow/tools/def_file_filter/BUILD.tpl (date 1629063191583) | ||
@@ -18,3 +18,8 @@ | ||
name = "symbols_pybind", | ||
srcs = ["symbols_pybind.txt"], | ||
) | ||
+ | ||
+filegroup( | ||
+ name = "symbols_java", | ||
+ srcs = ["symbols_java.txt"], | ||
+) | ||
Index: tensorflow/tools/def_file_filter/symbols_java.txt | ||
IDEA additional info: | ||
Subsystem: com.intellij.openapi.diff.impl.patch.CharsetEP | ||
<+>UTF-8 | ||
=================================================================== | ||
diff --git a/tensorflow/tools/def_file_filter/symbols_java.txt b/tensorflow/tools/def_file_filter/symbols_java.txt | ||
new file mode 100644 | ||
--- /dev/null (date 1629063607794) | ||
+++ b/tensorflow/tools/def_file_filter/symbols_java.txt (date 1629063607794) | ||
@@ -0,0 +1,26 @@ | ||
+[//tensorflow/cc:scope] # scope | ||
+tensorflow::Scope::graph | ||
+tensorflow::Scope::ok | ||
+tensorflow::Scope::UpdateBuilder | ||
+tensorflow::Scope::GetUniqueNameForOp | ||
+tensorflow::Scope::ExitOnError | ||
+tensorflow::Scope::WithDevice | ||
+tensorflow::Scope::WithNoControlDependencies | ||
+tensorflow::Scope::WithControlDependencies | ||
+tensorflow::Scope::NewSubScope | ||
+tensorflow::Scope::NewRootScope | ||
+tensorflow::Scope::operator= | ||
+tensorflow::Scope::~Scope | ||
+tensorflow::Scope::Scope | ||
+ | ||
+[//tensorflow/cc:ops] | ||
+tensorflow::Operation::Operation | ||
+ | ||
+[//tensorflow/cc:grad_op_registry] # custom gradients for graph | ||
+tensorflow::ops::GradOpRegistry::Global | ||
+tensorflow::ops::GradOpRegistry::Lookup | ||
+tensorflow::ops::GradOpRegistry::Register | ||
+ | ||
+[//tensorflow/c:tf_status_helper] # status helpers | ||
+tensorflow::Set_TF_Status_from_Status | ||
+tensorflow::StatusFromTF_Status | ||
=================================================================== | ||
diff --git a/tensorflow/tools/def_file_filter/def_file_filter.py.tpl b/tensorflow/tools/def_file_filter/def_file_filter.py.tpl | ||
--- a/tensorflow/tools/def_file_filter/def_file_filter.py.tpl (revision 919f693420e35d00c8d0a42100837ae3718f7927) | ||
+++ b/tensorflow/tools/def_file_filter/def_file_filter.py.tpl (date 1632048268359) | ||
@@ -143,8 +143,8 @@ | ||
re_filter_comp = re.compile(r"{}".format(re_filter)) | ||
|
||
# Filter out symbol from the split line (`sym_split` in the for loop below). | ||
- sym_line_filter = r".*\s+\| (.*) \(.*" | ||
- sym_line_filter_anomaly = r".*\s+\| (.*)" | ||
+ sym_line_filter = r".*\s+\| (.*?) \(.*" | ||
+ sym_line_filter_anomaly = r".*\s+\| (.*?)" | ||
|
||
for sym_line in sym_split: | ||
if re_filter_comp.search(sym_line): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
25 changes: 25 additions & 0 deletions
25
tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/GradFunc.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
// Targeted by JavaCPP version 1.5.6: DO NOT EDIT THIS FILE | ||
|
||
package org.tensorflow.internal.c_api; | ||
|
||
import java.nio.*; | ||
import org.bytedeco.javacpp.*; | ||
import org.bytedeco.javacpp.annotation.*; | ||
|
||
import static org.tensorflow.internal.c_api.global.tensorflow.*; | ||
|
||
|
||
/** GradFunc is the signature for all gradient functions in GradOpRegistry. | ||
* Implementations should add operations to compute the gradient outputs of | ||
* 'op' (returned in 'grad_outputs') using 'scope' and 'grad_inputs'. */ | ||
@Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class) | ||
public class GradFunc extends FunctionPointer { | ||
static { Loader.load(); } | ||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ | ||
public GradFunc(Pointer p) { super(p); } | ||
protected GradFunc() { allocate(); } | ||
private native void allocate(); | ||
public native @ByVal NativeStatus call(@Const @ByRef TF_Scope scope, @Const @ByRef NativeOperation op, | ||
@Const @ByRef NativeOutputVector grad_inputs, | ||
NativeOutputVector grad_outputs); | ||
} |
48 changes: 48 additions & 0 deletions
48
...w-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/GradOpRegistry.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
// Targeted by JavaCPP version 1.5.6: DO NOT EDIT THIS FILE | ||
|
||
package org.tensorflow.internal.c_api; | ||
|
||
import java.nio.*; | ||
import org.bytedeco.javacpp.*; | ||
import org.bytedeco.javacpp.annotation.*; | ||
|
||
import static org.tensorflow.internal.c_api.global.tensorflow.*; | ||
|
||
|
||
/** GradOpRegistry maintains a static registry of gradient functions. | ||
* Gradient functions are indexed in the registry by the forward op name (i.e. | ||
* "MatMul" -> MatMulGrad func). */ | ||
@Namespace("tensorflow::ops") @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class) | ||
public class GradOpRegistry extends Pointer { | ||
static { Loader.load(); } | ||
/** Default native constructor. */ | ||
public GradOpRegistry() { super((Pointer)null); allocate(); } | ||
/** Native array allocator. Access with {@link Pointer#position(long)}. */ | ||
public GradOpRegistry(long size) { super((Pointer)null); allocateArray(size); } | ||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ | ||
public GradOpRegistry(Pointer p) { super(p); } | ||
private native void allocate(); | ||
private native void allocateArray(long size); | ||
@Override public GradOpRegistry position(long position) { | ||
return (GradOpRegistry)super.position(position); | ||
} | ||
@Override public GradOpRegistry getPointer(long i) { | ||
return new GradOpRegistry((Pointer)this).offsetAddress(i); | ||
} | ||
|
||
/** Registers 'func' as the gradient function for 'op'. | ||
* Returns true if registration was successful, check fails otherwise. */ | ||
public native @Cast("bool") boolean Register(@StdString BytePointer op, GradFunc func); | ||
public native @Cast("bool") boolean Register(@StdString String op, GradFunc func); | ||
|
||
/** Sets 'func' to the gradient function for 'op' and returns Status OK if | ||
* the gradient function for 'op' exists in the registry. | ||
* Note that 'func' can be null for ops that have registered no-gradient with | ||
* the registry. | ||
* Returns error status otherwise. */ | ||
public native @ByVal NativeStatus Lookup(@StdString BytePointer op, @ByPtrPtr GradFunc func); | ||
public native @ByVal NativeStatus Lookup(@StdString String op, @ByPtrPtr GradFunc func); | ||
|
||
/** Returns a pointer to the global gradient function registry. */ | ||
public static native GradOpRegistry Global(); | ||
} |
41 changes: 41 additions & 0 deletions
41
tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/internal/c_api/NameMap.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
// Targeted by JavaCPP version 1.5.6: DO NOT EDIT THIS FILE | ||
|
||
package org.tensorflow.internal.c_api; | ||
|
||
import java.nio.*; | ||
import org.bytedeco.javacpp.*; | ||
import org.bytedeco.javacpp.annotation.*; | ||
|
||
import static org.tensorflow.internal.c_api.global.tensorflow.*; | ||
|
||
@Name("std::unordered_map<tensorflow::string,tensorflow::Node*>") @Properties(inherit = org.tensorflow.internal.c_api.presets.tensorflow.class) | ||
public class NameMap extends Pointer { | ||
static { Loader.load(); } | ||
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ | ||
public NameMap(Pointer p) { super(p); } | ||
public NameMap() { allocate(); } | ||
private native void allocate(); | ||
public native @Name("operator =") @ByRef NameMap put(@ByRef NameMap x); | ||
|
||
public boolean empty() { return size() == 0; } | ||
public native long size(); | ||
|
||
@Index public native Node get(@StdString BytePointer i); | ||
public native NameMap put(@StdString BytePointer i, Node value); | ||
|
||
public native void erase(@ByVal Iterator pos); | ||
public native @ByVal Iterator begin(); | ||
public native @ByVal Iterator end(); | ||
@NoOffset @Name("iterator") public static class Iterator extends Pointer { | ||
public Iterator(Pointer p) { super(p); } | ||
public Iterator() { } | ||
|
||
public native @Name("operator ++") @ByRef Iterator increment(); | ||
public native @Name("operator ==") boolean equals(@ByRef Iterator it); | ||
public native @Name("operator *().first") @MemberGetter @StdString BytePointer first(); | ||
public native @Name("operator *().second") @MemberGetter @Const Node second(); | ||
} | ||
|
||
public native long erase(@StdString BytePointer key); | ||
} | ||
|
Oops, something went wrong.