Skip to content

Commit

Permalink
Add missing requireHandles
Browse files Browse the repository at this point in the history
Signed-off-by: Ryan Nett <[email protected]>
  • Loading branch information
rnett committed Apr 28, 2021
1 parent 478924e commit 04aba60
Showing 1 changed file with 10 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -188,13 +188,15 @@ Tensor tensor(int outputIdx) {
* Get the number of inputs to the op, not including control inputs.
*/
public int numInputs() {
requireHandle(unsafeNativeHandle);
return TF_OperationNumInputs(getUnsafeNativeHandle());
}

/**
* Get the op's inputs, not including control inputs.
*/
public List<Operand<?>> inputs() {
requireHandle(unsafeNativeHandle);
try (PointerScope scope = new PointerScope()) {
int numInputs = numInputs();
TF_Output handles = new TF_Output(numInputs);
Expand All @@ -219,6 +221,7 @@ public List<Operand<?>> inputs() {
* @param index the output to look for usages of
*/
public int numConsumers(int index) {
requireHandle(unsafeNativeHandle);
try (PointerScope scope = new PointerScope()) {
TF_Output output = new TF_Output().oper(getUnsafeNativeHandle()).index(index);
return TF_OperationOutputNumConsumers(output);
Expand All @@ -231,6 +234,7 @@ public int numConsumers(int index) {
* @param index the output to look for usages of
*/
public Set<GraphOperation> consumers(int index) {
requireHandle(unsafeNativeHandle);
try (PointerScope scope = new PointerScope()) {
TF_Output output = new TF_Output().oper(getUnsafeNativeHandle()).index(index);
int numConsumers = numConsumers(index);
Expand All @@ -253,6 +257,7 @@ public Set<GraphOperation> consumers(int index) {
* Get the number of ops that use any of this op's outputs as an input, not including control dependencies.
*/
public int numConsumers() {
requireHandle(unsafeNativeHandle);
int all = 0;
for (int i = 0; i < numOutputs(); i++) {
all += numConsumers(i);
Expand All @@ -265,6 +270,7 @@ public int numConsumers() {
* Get the ops that use any of this op's outputs as an input, not including control dependencies.
*/
public Set<GraphOperation> consumers() {
requireHandle(unsafeNativeHandle);
Set<GraphOperation> all = new LinkedHashSet<>();
for (int i = 0; i < numOutputs(); i++) {
all.addAll(consumers(i));
Expand All @@ -276,6 +282,7 @@ public Set<GraphOperation> consumers() {
* Get the number of control inputs for this op.
*/
public int numControlInputs() {
requireHandle(unsafeNativeHandle);
try (PointerScope scope = new PointerScope()) {
return TF_OperationNumControlInputs(getUnsafeNativeHandle());
}
Expand All @@ -285,6 +292,7 @@ public int numControlInputs() {
* Get the control inputs of this op.
*/
public Set<GraphOperation> controlInputs() {
requireHandle(unsafeNativeHandle);
try (PointerScope scope = new PointerScope()) {
int numInputs = numControlInputs();
PointerPointer<TF_Operation> handles = new PointerPointer<>(numInputs);
Expand All @@ -305,6 +313,7 @@ public Set<GraphOperation> controlInputs() {
* Get the number of ops with this op as a control dependency.
*/
public int numControlConsumers() {
requireHandle(unsafeNativeHandle);
try (PointerScope scope = new PointerScope()) {
return TF_OperationNumControlOutputs(getUnsafeNativeHandle());
}
Expand All @@ -314,6 +323,7 @@ public int numControlConsumers() {
* Get the ops with this op as a control dependency.
*/
public Set<GraphOperation> controlConsumers() {
requireHandle(unsafeNativeHandle);
try (PointerScope scope = new PointerScope()) {
int numConsumers = numControlConsumers();
PointerPointer<TF_Operation> handles = new PointerPointer<>(numConsumers);
Expand Down

0 comments on commit 04aba60

Please sign in to comment.