diff --git a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GradientAdapterHelpers.java b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GradientAdapterHelpers.java index bd20eab7136..2bfd1f96c6f 100644 --- a/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GradientAdapterHelpers.java +++ b/tensorflow-core/tensorflow-core-api/src/main/java/org/tensorflow/GradientAdapterHelpers.java @@ -1,29 +1,31 @@ /* - Copyright 2021 The TensorFlow Authors. All Rights Reserved. + Copyright 2021 The TensorFlow Authors. All Rights Reserved. - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at - http://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - ============================================================================== - */ +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +============================================================================== +*/ package org.tensorflow; +import static org.tensorflow.internal.c_api.global.tensorflow.OutputsFromTFOutputs; +import static org.tensorflow.internal.c_api.global.tensorflow.TFOutputsFromOutputs; import static org.tensorflow.internal.c_api.global.tensorflow.ToOperation; import java.util.ArrayList; import java.util.List; import org.bytedeco.javacpp.PointerScope; -import org.tensorflow.internal.c_api.NativeOutput; import org.tensorflow.internal.c_api.NativeOutputVector; import org.tensorflow.internal.c_api.Node; +import org.tensorflow.internal.c_api.TF_Output; /** * Helpers for {@link org.tensorflow.op.TypedGradientAdapter} and {@link @@ -34,15 +36,16 @@ public class GradientAdapterHelpers { /** * Convert a array of native outputs to a list of {@link Output}s. * - * @param g the graph the outputs are in + * @param g the graph the outputs are in * @param nativeOutputs the native outputs to convert */ public static List> fromNativeOutputs(Graph g, NativeOutputVector nativeOutputs) { + TF_Output outputs = new TF_Output(nativeOutputs.size()); + TFOutputsFromOutputs(nativeOutputs, outputs); List> gradInputs = new ArrayList<>((int) nativeOutputs.size()); for (int i = 0; i < nativeOutputs.size(); i++) { - NativeOutput output = nativeOutputs.get(i); - gradInputs.add(new Output<>(getGraphOp(g, output.node()), - output.index())); + TF_Output output = outputs.position(i); + gradInputs.add(new Output<>(new GraphOperation(g, output.oper()), output.index())); } return gradInputs; } @@ -50,23 +53,32 @@ public static List> fromNativeOutputs(Graph g, NativeOutputVector nati /** * Put the Java outputs into the array of native outputs, resizing it to the necessary size. * - * @param outputs the outputs to put + * @param outputs the outputs to put * @param nativeOutputs the native array to put the outputs into */ - public static void putToNativeOutputs(List> outputs, - NativeOutputVector nativeOutputs) { + public static void putToNativeOutputs( + List> outputs, NativeOutputVector nativeOutputs) { nativeOutputs.resize(outputs.size()); + + TF_Output tempOutputs = new TF_Output(outputs.size()); for (int i = 0; i < outputs.size(); i++) { Output output = outputs.get(i).asOutput(); - Node node = ((GraphOperation) output.op()).getUnsafeNativeHandle().node(); - nativeOutputs.put(i, new NativeOutput(node, output.index())); + GraphOperation graphOp = (GraphOperation) output.op(); + tempOutputs + .position(i) + .put(new TF_Output().oper(graphOp.getUnsafeNativeHandle()).index(output.index())); + } + + NativeOutputVector temp = OutputsFromTFOutputs(tempOutputs, outputs.size()); + for (int i = 0; i < outputs.size(); i++) { + nativeOutputs.put(i, temp.get(i)); } } /** * Make a {@link GraphOperation} from a native {@link Node} * - * @param g the graph the operation is in + * @param g the graph the operation is in * @param node the native node * @return a graph operation with the underlying native node */