Skip to content

Commit

Permalink
Use better conversion functions
Browse files Browse the repository at this point in the history
Signed-off-by: Ryan Nett <[email protected]>
  • Loading branch information
rnett committed Apr 29, 2021
1 parent 59750dc commit b39bd2f
Showing 1 changed file with 35 additions and 23 deletions.
Original file line number Diff line number Diff line change
@@ -1,29 +1,31 @@
/*
Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================
*/
package org.tensorflow;

import static org.tensorflow.internal.c_api.global.tensorflow.OutputsFromTFOutputs;
import static org.tensorflow.internal.c_api.global.tensorflow.TFOutputsFromOutputs;
import static org.tensorflow.internal.c_api.global.tensorflow.ToOperation;

import java.util.ArrayList;
import java.util.List;
import org.bytedeco.javacpp.PointerScope;
import org.tensorflow.internal.c_api.NativeOutput;
import org.tensorflow.internal.c_api.NativeOutputVector;
import org.tensorflow.internal.c_api.Node;
import org.tensorflow.internal.c_api.TF_Output;

/**
* Helpers for {@link org.tensorflow.op.TypedGradientAdapter} and {@link
Expand All @@ -34,39 +36,49 @@ public class GradientAdapterHelpers {
/**
* Convert a array of native outputs to a list of {@link Output}s.
*
* @param g the graph the outputs are in
* @param g the graph the outputs are in
* @param nativeOutputs the native outputs to convert
*/
public static List<Output<?>> fromNativeOutputs(Graph g, NativeOutputVector nativeOutputs) {
TF_Output outputs = new TF_Output(nativeOutputs.size());
TFOutputsFromOutputs(nativeOutputs, outputs);
List<Output<?>> gradInputs = new ArrayList<>((int) nativeOutputs.size());
for (int i = 0; i < nativeOutputs.size(); i++) {
NativeOutput output = nativeOutputs.get(i);
gradInputs.add(new Output<>(getGraphOp(g, output.node()),
output.index()));
TF_Output output = outputs.position(i);
gradInputs.add(new Output<>(new GraphOperation(g, output.oper()), output.index()));
}
return gradInputs;
}

/**
* Put the Java outputs into the array of native outputs, resizing it to the necessary size.
*
* @param outputs the outputs to put
* @param outputs the outputs to put
* @param nativeOutputs the native array to put the outputs into
*/
public static void putToNativeOutputs(List<Operand<?>> outputs,
NativeOutputVector nativeOutputs) {
public static void putToNativeOutputs(
List<Operand<?>> outputs, NativeOutputVector nativeOutputs) {
nativeOutputs.resize(outputs.size());

TF_Output tempOutputs = new TF_Output(outputs.size());
for (int i = 0; i < outputs.size(); i++) {
Output<?> output = outputs.get(i).asOutput();
Node node = ((GraphOperation) output.op()).getUnsafeNativeHandle().node();
nativeOutputs.put(i, new NativeOutput(node, output.index()));
GraphOperation graphOp = (GraphOperation) output.op();
tempOutputs
.position(i)
.put(new TF_Output().oper(graphOp.getUnsafeNativeHandle()).index(output.index()));
}

NativeOutputVector temp = OutputsFromTFOutputs(tempOutputs, outputs.size());
for (int i = 0; i < outputs.size(); i++) {
nativeOutputs.put(i, temp.get(i));
}
}

/**
* Make a {@link GraphOperation} from a native {@link Node}
*
* @param g the graph the operation is in
* @param g the graph the operation is in
* @param node the native node
* @return a graph operation with the underlying native node
*/
Expand Down

0 comments on commit b39bd2f

Please sign in to comment.