diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java index ea3ef31313e..acbae4dac6b 100644 --- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java +++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java @@ -354,20 +354,20 @@ public final class Ops { public final SparseOps sparse; - public final TpuOps tpu; - public final BitwiseOps bitwise; + public final TpuOps tpu; + public final MathOps math; public final AudioOps audio; public final SignalOps signal; - public final TrainOps train; - public final QuantizationOps quantization; + public final TrainOps train; + private final Scope scope; private Ops(Scope scope) { @@ -385,13 +385,13 @@ private Ops(Scope scope) { random = new RandomOps(this); strings = new StringsOps(this); sparse = new SparseOps(this); - tpu = new TpuOps(this); bitwise = new BitwiseOps(this); + tpu = new TpuOps(this); math = new MathOps(this); audio = new AudioOps(this); signal = new SignalOps(this); - train = new TrainOps(this); quantization = new QuantizationOps(this); + train = new TrainOps(this); } /** diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Activation.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Activation.java index e1482a51a8a..708c7daead6 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Activation.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Activation.java @@ -18,14 +18,7 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.family.TNumber; -/** - * Abstract base class for Activations - * - *

Note: The {@link #tf} attribute must be set prior to invoking the call method. See - * {@link #setTF(Ops)} and the constructor {@link #Activation(Ops)}. - * - * @param the data type of the activation - */ +/** Abstract base class for Activations */ public abstract class Activation { /** The TensorFlow Ops */ @@ -41,28 +34,29 @@ protected Activation(Ops tf) { } /** - * Sets the TensorFlow Ops + * Gets the TensorFlow Ops * - * @param tf the TensorFlow Ops + * @return the TensorFlow Ops */ - protected void setTF(Ops tf) { - this.tf = tf; + protected Ops getTF() { + return this.tf; } /** - * Gets the TensorFlow Ops + * Sets the TensorFlow Ops * - * @return the TensorFlow Ops + * @param tf the TensorFlow Ops */ - protected Ops getTF() { - return this.tf; + protected void setTF(Ops tf) { + this.tf = tf; } /** * Gets the calculation operation for the activation. * * @param input the input tensor + * @param the data type of the input and result * @return The operand for the activation */ - public abstract Operand call(Operand input); + public abstract Operand call(Operand input); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ELU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ELU.java index 2f2f16f2752..00c720c936b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ELU.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ELU.java @@ -44,11 +44,10 @@ * Operand<TFloat32> result = elu.call(input); * * - * @param the data type of the activation * @see Clevert et al, 2016, Fast and Accurate Deep * Network Learning by Exponential Linear Units (ELUs) */ -public class ELU extends Activation { +public class ELU extends Activation { private static final double ALPHA_DEFAULT = 1.0; @@ -76,20 +75,16 @@ public ELU(Ops tf, double alpha) { this.alpha = alpha; } - /** - * Gets the calculation operation for the activation. - * - * @param input the input tensor - * @return The operand for the activation - */ + /** {@inheritDoc} */ @Override - public Operand call(Operand input) { + public Operand call(Operand input) { - Operand result = tf.nn.elu(input); - if (alpha == 1.0) return result; - else { - Class inputType = input.type(); - Operand y = tf.math.mul(result, tf.dtypes.cast(tf.constant(alpha), inputType)); + Operand result = tf.nn.elu(input); + if (alpha == 1.0) { + return result; + } else { + Class inputType = input.type(); + Operand y = tf.math.mul(result, tf.dtypes.cast(tf.constant(alpha), inputType)); Operand cond = tf.math.greater(result, tf.dtypes.cast(tf.constant(0), inputType)); return tf.select(cond, result, y); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Exponential.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Exponential.java index d5fdff36c61..512f96713aa 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Exponential.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Exponential.java @@ -30,10 +30,8 @@ * Operand<TFloat32> result = exp.call(input); * // result is [0.04978707f, 0.36787945f, 1.f, 2.7182817f, 20.085537f] * - * - * @param the data type of the activation */ -public class Exponential extends Activation { +public class Exponential extends Activation { /** * Creates an Exponential activation. @@ -48,10 +46,12 @@ public Exponential(Ops tf) { * Calculates the Exponential activation. 
* * @param input the input tensor + * @param the data type of the input and result * @return an Operand for the exponential activation: exp(x). */ @Override - public Operand call(Operand input) { + public Operand call(Operand input) { + return tf.math.exp(input); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/GeLU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/GeLU.java new file mode 100644 index 00000000000..abddbd512cf --- /dev/null +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/GeLU.java @@ -0,0 +1,116 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +=======================================================================*/ +package org.tensorflow.framework.activations; + +import org.tensorflow.Operand; +import org.tensorflow.op.Ops; +import org.tensorflow.types.family.TFloating; + +import static org.tensorflow.framework.utils.CastHelper.cast; + +/** + * Applies the Gaussian error linear unit (GELU) activation function. + * + *

Gaussian error linear unit (GELU) computes {@code x * P(X <= x)}, where {@code P(X) ~ N(0, + * 1)}. The (GELU) nonlinearity weights inputs by their value, rather than gates inputs by their + * sign as in ReLU. If approximate is true: + * + *

+ *     0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
+ * 
+ * + *

or, if approximate is false. + * + *

+ *     x * P(X <= x) = 0.5 * x * (1 + erf(x / sqrt(2))),
+ * 
+ * + * where P(X) ~ N(0, 1). + * + * @see Hendrycks, Dan and Gimpel, Kevin, 2016-2020, + * Gaussian Error Linear Units (GELUs) + */ +public class GeLU extends Activation { + + private final boolean approximate; + + /** + * Creates a e Gaussian error linear unit (GELU) activation, with approximate set to false + * + * @param tf The TensorFlow ops + */ + public GeLU(Ops tf) { + this(tf, false); + } + + /** + * Creates a e Gaussian error linear unit (GELU) activation + * + * @param tf The TensorFlow ops + * @param approximate indicator whether to enable approximation. + */ + public GeLU(Ops tf, boolean approximate) { + super(tf); + this.approximate = approximate; + } + + /** {@inheritDoc} */ + @Override + public Operand call(Operand input) { + + Operand point5 = cast(tf, tf.constant(0.5), input.type()); + Operand one = cast(tf, tf.constant(1.0), input.type()); + + if (approximate) { + /* + coeff = math_ops.cast(0.044715, features.dtype) + return 0.5 * features * ( + 1.0 + math_ops.tanh(0.7978845608028654 * + (features + coeff * math_ops.pow(features, 3)))) + */ + Operand coeff = cast(tf, tf.constant(0.044715), input.type()); + // sqrt(2.0 / PI) + Operand sqrtTwoDivPI = cast(tf, tf.constant(0.7978845608028654), input.type()); + Operand three = cast(tf, tf.constant(3), input.type()); + + return tf.math.mul( + point5, + tf.math.mul( + input, + tf.math.add( + one, + tf.math.tanh( + tf.math.mul( + sqrtTwoDivPI, + tf.math.add( + input, tf.math.mul(coeff, tf.math.pow(input, three)) // mul + ) // add + ) // mul + ) // tanh + ) // add + ) // mul + ); // mul + + } else { + /* + return 0.5 * features * (1.0 + math_ops.erf( + features / math_ops.cast(1.4142135623730951, features.dtype))) + */ + Operand sqrtTwo = cast(tf, tf.constant(1.4142135623730951), input.type()); + return tf.math.mul( + point5, tf.math.mul(input, tf.math.add(one, tf.math.erf(tf.math.div(input, sqrtTwo))))); + } + } +} diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/HardSigmoid.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/HardSigmoid.java index 0b7cf573b8e..6c7323a90cb 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/HardSigmoid.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/HardSigmoid.java @@ -16,7 +16,7 @@ import org.tensorflow.Operand; import org.tensorflow.op.Ops; -import org.tensorflow.types.family.TFloating; +import org.tensorflow.types.family.TNumber; /** * Hard sigmoid activation. @@ -40,10 +40,8 @@ * Operand<TFloat32> result = hardSigmoid.call(input); * // result is [0.f , 0.3f, 0.5f, 0.7f, 1.f] * - * - * @param the data type of the result */ -public class HardSigmoid extends Activation { +public class HardSigmoid extends Activation { /** * Creates Hard sigmoid activation. @@ -54,19 +52,14 @@ public HardSigmoid(Ops tf) { super(tf); } - /** - * Gets the calculation operation for the activation. 
- * - * @param input the input tensor - * @return The operand for the activation - */ + /** {@inheritDoc} */ @Override - public Operand call(Operand input) { - Class inputType = input.type(); - Operand point2 = tf.dtypes.cast(tf.constant(0.2), inputType); - Operand point5 = tf.dtypes.cast(tf.constant(0.5), inputType); + public Operand call(Operand input) { + Class inputType = input.type(); + Operand point2 = tf.dtypes.cast(tf.constant(0.2), inputType); + Operand point5 = tf.dtypes.cast(tf.constant(0.5), inputType); - Operand x = tf.math.add(tf.math.mul(input, point2), point5); + Operand x = tf.math.add(tf.math.mul(input, point2), point5); return tf.clipByValue( x, tf.dtypes.cast(tf.constant(0), inputType), tf.dtypes.cast(tf.constant(1), inputType)); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Linear.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Linear.java index d907397995d..dcda76db6bf 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Linear.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Linear.java @@ -19,9 +19,9 @@ import org.tensorflow.types.family.TNumber; /** - * Linear activation function (pass-through). + * Linear activation function (pass-through). * - *

The linear activation returns its input. It is also known as the Identity activation function.

+ *

The linear activation returns its input. It is also known as the Identity activation function. * *

For example: * @@ -33,7 +33,7 @@ * // result is [-3.0f,-1.0f, 0.0f,1.0f,3.0f] * */ -public class Linear extends Activation { +public class Linear extends Activation { /** * Creates a linear activation. @@ -46,7 +46,7 @@ public Linear(Ops tf) { /** {@inheritDoc} */ @Override - public Operand call(Operand input) { + public Operand call(Operand input) { return input; } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ReLU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ReLU.java index aef6ebf2992..409015a1203 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ReLU.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ReLU.java @@ -55,10 +55,8 @@ * result = relu.call(input); * // result is [-0.f, -0.f, 0.f, 0.f, 10.f] * - * - * @param the data type of the result */ -public class ReLU extends Activation { +public class ReLU extends Activation { public static final float ALPHA_DEFAULT = 0.0f; public static final float MAX_VALUE_DEFAULT = Float.NaN; @@ -96,11 +94,11 @@ public ReLU(Ops tf, float alpha, float maxValue, float threshold) { /** {@inheritDoc} */ @Override - public Operand call(Operand input) { - Class inputType = input.type(); + public Operand call(Operand input) { + Class inputType = input.type(); boolean clipMax = !Float.isNaN(maxValue); - Operand negativePart = null; + Operand negativePart = null; if (alpha != 0) { if (Float.isNaN(maxValue) && threshold == 0) { return tf.nn.leakyRelu(input, LeakyRelu.alpha(alpha)); @@ -114,7 +112,7 @@ public Operand call(Operand input) { } } - Operand lInput; + Operand lInput; if (threshold != 0) { // computes input for input > threshold else 0 Greater greater = tf.math.greater(input, tf.dtypes.cast(tf.constant(threshold), inputType)); @@ -127,8 +125,8 @@ public Operand call(Operand input) { lInput = tf.nn.relu(input); } if (clipMax) { - Operand lmaxValue = tf.dtypes.cast(tf.constant(maxValue), inputType); - Operand zero = tf.dtypes.cast(tf.constant(0), inputType); + Operand lmaxValue = tf.dtypes.cast(tf.constant(maxValue), inputType); + Operand zero = tf.dtypes.cast(tf.constant(0), inputType); lInput = tf.clipByValue(lInput, zero, lmaxValue); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/SELU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/SELU.java index f24731049fb..7955144abbb 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/SELU.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/SELU.java @@ -42,10 +42,9 @@ *

Notes: To be used together with the {@link * org.tensorflow.framework.initializers.LeCun} initializer with Normal Distribution. * - * @param the data type of the activation * @see Klambauer et al., 2017 */ -public class SELU extends Activation { +public class SELU extends Activation { /** * Creates a Scaled Exponential Linear Unit (SELU) activation. @@ -56,14 +55,9 @@ public SELU(Ops tf) { super(tf); } - /** - * Gets the calculation operation for the activation. - * - * @param input the input tensor - * @return The operand for the activation - */ + /** {@inheritDoc} */ @Override - public Operand call(Operand input) { + public Operand call(Operand input) { return tf.nn.selu(input); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Sigmoid.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Sigmoid.java index 5d507b38483..a89b6119d02 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Sigmoid.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Sigmoid.java @@ -38,10 +38,8 @@ * // result is [2.0611537e-09f, 2.6894143e-01f, * // 5.0000000e-01f,7.3105860e-01f, 1.f] * - * - * @param the data type of the activation */ -public class Sigmoid extends Activation { +public class Sigmoid extends Activation { /** * Creates a Sigmoid activation. @@ -52,14 +50,9 @@ public Sigmoid(Ops tf) { super(tf); } - /** - * Gets the calculation operation for the activation. - * - * @param input the input tensor - * @return The operand for the activation - */ + /** {@inheritDoc} */ @Override - public Operand call(Operand input) { + public Operand call(Operand input) { return tf.math.sigmoid(input); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softmax.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softmax.java index 154e1ecc84a..309cdd68b8b 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softmax.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softmax.java @@ -35,10 +35,8 @@ *

The softmax of each vector x is computed as: exp(x) / tf.sum(exp(x)). * *

The input values in are the log-odds of the resulting probability. - * - * @param the data type of the activation */ -public class Softmax extends Activation { +public class Softmax extends Activation { private static final int AXIS_DEFAULT = -1; @@ -65,23 +63,18 @@ public Softmax(Ops tf, int axis) { this.axis = axis; } - /** - * Gets the calculation operation for the activation. - * - * @param input the input tensor - * @return The operand for the activation - */ + /** {@inheritDoc} */ @Override - public Operand call(Operand input) { + public Operand call(Operand input) { Shape shape = input.shape(); int numDimensions = shape.numDimensions(); if (numDimensions == 2) { return tf.nn.softmax(input); } else { - Operand e = + Operand e = tf.math.exp( tf.math.sub(input, tf.reduceMax(input, tf.constant(axis), ReduceMax.keepDims(true)))); - Operand s = tf.reduceSum(e, tf.constant(axis), ReduceSum.keepDims(true)); + Operand s = tf.reduceSum(e, tf.constant(axis), ReduceSum.keepDims(true)); return tf.math.div(e, s); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softplus.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softplus.java index 65a183ea047..0eb703aad9f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softplus.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softplus.java @@ -32,7 +32,7 @@ * // 1.3132616e+00f, 2.0000000e+01f] * */ -public class Softplus extends Activation { +public class Softplus extends Activation { /** * Creates a Softplus activation function. @@ -43,14 +43,9 @@ public Softplus(Ops tf) { super(tf); } - /** - * Gets the calculation operation for the activation. - * - * @param input the input tensor - * @return The operand for the activation - */ + /** {@inheritDoc} */ @Override - public Operand call(Operand input) { + public Operand call(Operand input) { return tf.math.softplus(input); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softsign.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softsign.java index 1f691e71862..0a7754258df 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softsign.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softsign.java @@ -30,10 +30,8 @@ * Operand<TFloat32> result = softsign.call(input); * // result is [-0.5f, 0.f, 0.5f] * - * - * @param the data type of the activation */ -public class Softsign extends Activation { +public class Softsign extends Activation { /** * Creates a Softsign activation. @@ -44,14 +42,9 @@ public Softsign(Ops tf) { super(tf); } - /** - * Gets the calculation operation for the activation. 
- * - * @param input the input tensor - * @return The operand for the activation - */ + /** {@inheritDoc} */ @Override - public Operand call(Operand input) { + public Operand call(Operand input) { return tf.nn.softsign(input); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Swish.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Swish.java index d9f73a422d5..eb43a21f285 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Swish.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Swish.java @@ -37,10 +37,9 @@ * * * - * @param the data type of the activation * @see Ramachandran et al., 2017 */ -public class Swish extends Activation { +public class Swish extends Activation { /** * Creates a Swish activation, swish(x) = x * sigmoid(x). @@ -57,7 +56,7 @@ public Swish(Ops tf) { /** {@inheritDoc} */ @Override - public Operand call(Operand input) { + public Operand call(Operand input) { // TODO Python Keras returns a "grad", which is an optimization not implemented in Java. return tf.math.mul(input, tf.math.sigmoid(input)); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Tanh.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Tanh.java index 4fe02eed048..145561b9129 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Tanh.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Tanh.java @@ -30,10 +30,8 @@ * Operand<TFloat32> result = tanh.call(input); * // result = [-0.9950547f, -0.7615942f, 0.f, 0.7615942f, 0.9950547f] * - * - * @param the data type of the activation */ -public class Tanh extends Activation { +public class Tanh extends Activation { /** * Creates a Hyperbolic tangent activation. @@ -46,7 +44,7 @@ public Tanh(Ops tf) { /** {@inheritDoc} */ @Override - public Operand call(Operand input) { + public Operand call(Operand input) { return tf.math.tanh(input); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java index d3094b5e9e9..739f9f55c55 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java @@ -41,7 +41,9 @@ public Constraint(Ops tf) { * Applies the constraint against the provided weights * * @param weights the weights + * @param the data the weights and result * @return the constrained weights + * */ public abstract Operand call(Operand weights); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Constant.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Constant.java index 4a2df86d74b..3f2ebe58cb4 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Constant.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Constant.java @@ -16,11 +16,11 @@ import org.tensorflow.Operand; import org.tensorflow.op.Ops; -import org.tensorflow.types.TBool; import org.tensorflow.types.TInt64; -import org.tensorflow.types.family.TNumber; import org.tensorflow.types.family.TType; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * Initializer that generates tensors with a constant value. 
* @@ -33,76 +33,39 @@ * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); * * - * @param The Type for the call operation + *

Only scalar values are allowed. The constant value provided must be convertible to the data + * type requested when calling the initializer. */ -public class Constant extends BaseInitializer { - - private final double doubleValue; - private final long longValue; - private final boolean booleanValue; - private final ValueType valueType; +public class Constant extends BaseInitializer { - /** - * Creates an Initializer that generates tensors with a constant value. - * - * @param tf the TensorFlow Ops - * @param value a long value used for the constant. - */ - public Constant(Ops tf, long value) { - super(tf); - longValue = value; - doubleValue = 0; - booleanValue = false; - valueType = ValueType.LONG; - } + private final Operand value; /** * Creates an Initializer that generates tensors with a constant value. * * @param tf the TensorFlow Ops - * @param value a double value used for the constant. + * @param value the value used for the constant. + * @throws IllegalArgumentException if value is not a scalar. */ - public Constant(Ops tf, double value) { + public Constant(Ops tf, Operand value) { super(tf); - doubleValue = value; - longValue = 0; - booleanValue = false; - valueType = ValueType.DOUBLE; + if (!value.shape().isScalar()) { + throw new IllegalArgumentException("value must be scalar, got shape : " + value.shape()); + } + this.value = value; } /** - * Creates an Initializer that generates tensors with a constant value. + * Generates the operation used to perform the initialization. * - * @param tf the TensorFlow Ops - * @param value a boolean value used for the constant. + * @param dims the shape dimensions + * @param type the data type of tensor + * @param The data Type for initializer operation + * @return An operand for the initialization. */ - public Constant(Ops tf, boolean value) { - super(tf); - booleanValue = value; - doubleValue = 0; - longValue = 0; - valueType = ValueType.BOOLEAN; - } - - /** {@inheritDoc} */ @Override - public Operand call(Operand dims, Class type) { - if (!TNumber.class.isAssignableFrom(type) && type != TBool.class) { - throw new IllegalArgumentException("Tensor type must be numeric or boolean: " + type.getSimpleName()); - } - switch (valueType) { - case LONG: - return tf.fill(dims, tf.dtypes.cast(tf.constant(longValue), type)); - case DOUBLE: - return tf.fill(dims, tf.dtypes.cast(tf.constant(doubleValue), type)); - default: - return tf.fill(dims, tf.dtypes.cast(tf.constant(booleanValue), type)); - } - } + public Operand call(Operand dims, Class type) { - private enum ValueType { - LONG, - DOUBLE, - BOOLEAN + return tf.fill(dims, cast(tf, value, type)); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Glorot.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Glorot.java index 894bd073758..5a3c291785f 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Glorot.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Glorot.java @@ -16,7 +16,6 @@ package org.tensorflow.framework.initializers; import org.tensorflow.op.Ops; -import org.tensorflow.types.family.TFloating; /** * The Glorot initializer, also called Xavier initializer. @@ -58,16 +57,17 @@ * * *

NOTE: + * *

For a GlorotNormal equivalent initializer, use {@link * VarianceScaling.Distribution#TRUNCATED_NORMAL} for the distribution parameter. + * *

For a GlorotUniform equivalent initializer, use {@link VarianceScaling.Distribution#UNIFORM} * for the distribution parameter. * - * @param The TType for the call operation * @see VarianceScaling.Distribution * @see Glorot et al., 2010 */ -public class Glorot extends VarianceScaling { +public class Glorot extends VarianceScaling { public static final double SCALE = 1.0; @@ -77,7 +77,7 @@ public class Glorot extends VarianceScaling { * @param tf the TensorFlow Ops * @param distribution The distribution type for the Glorot initializer. * @param seed the seed for random number generation. An initializer created with a given seed - * will always produce the same random tensor for a given shape and dtype. + * will always produce the same random tensor for a given shape and data type. * @see VarianceScaling.Distribution */ public Glorot(Ops tf, Distribution distribution, long seed) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/He.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/He.java index 3a91b72b0d0..ac64e449265 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/He.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/He.java @@ -15,7 +15,6 @@ package org.tensorflow.framework.initializers; import org.tensorflow.op.Ops; -import org.tensorflow.types.family.TFloating; /** * He initializer. @@ -53,17 +52,18 @@ * * *

NOTE: + * *

For an HeNormal equivalent initializer, use {@link * VarianceScaling.Distribution#TRUNCATED_NORMAL} for the distribution parameter. - *

For an HeUniform equivalent initializer, use {@link VarianceScaling.Distribution#UNIFORM} - * for the distribution parameter. * - * @param The TType for the call operation + *

For an HeUniform equivalent initializer, use {@link VarianceScaling.Distribution#UNIFORM} for + * the distribution parameter. + * * @see He * et al., 2015 */ -public class He extends VarianceScaling { +public class He extends VarianceScaling { public static final double SCALE = 2.0; @@ -73,7 +73,7 @@ public class He extends VarianceScaling { * @param tf the TensorFlow Ops * @param distribution The distribution type for the He initializer. * @param seed the seed for random number generation. An initializer created with a given seed - * will always produce the same random tensor for a given shape and dtype. + * will always produce the same random tensor for a given shape and data type. * @see VarianceScaling.Distribution */ public He(Ops tf, Distribution distribution, long seed) { diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Identity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Identity.java index f672c9f1e85..2b1ccf00bae 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Identity.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Identity.java @@ -21,6 +21,8 @@ import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TFloating; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * Initializer that generates the identity matrix. * @@ -34,10 +36,8 @@ * Operand<TFloat32> values = * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); * - * - * @param The TType for the call operation */ -public class Identity extends BaseInitializer { +public class Identity extends BaseInitializer { public static final double GAIN_DEFAULT = 1.0; private final double gain; @@ -65,7 +65,7 @@ public Identity(Ops tf, double gain) { /** {@inheritDoc} */ @Override - public Operand call(Operand dims, Class type) { + public Operand call(Operand dims, Class type) { Shape shape = ShapeUtils.toShape(tf.scope(), dims); if (shape.numDimensions() != 2) { throw new IllegalArgumentException("2D matrix required, got " + shape.numDimensions()); @@ -74,10 +74,10 @@ public Operand call(Operand dims, Class type) { long diagSize = Math.min(shape.size(0), shape.size(1)); Shape diagShape = Shape.of(diagSize); - Operand op; - Operand zero = tf.dtypes.cast(tf.constant(0), type); - Operand diagOnes = - tf.fill(tf.constant(diagShape.asArray()), tf.dtypes.cast(tf.constant(1.0), type)); + Operand op; + Operand zero = cast(tf, tf.constant(0), type); + Operand diagOnes = + tf.fill(tf.constant(diagShape.asArray()), cast(tf, tf.constant(1.0), type)); if (isSquare) { op = tf.linalg.matrixDiag( @@ -87,10 +87,10 @@ public Operand call(Operand dims, Class type) { tf.constant((int) shape.size(1)), zero); } else { - Operand zeroMatrix = tf.zeros(dims, type); + Operand zeroMatrix = tf.zeros(dims, type); op = tf.linalg.matrixSetDiag(zeroMatrix, diagOnes, tf.constant(0)); } - return tf.math.mul(op, tf.dtypes.cast(tf.constant(gain), type)); + return tf.math.mul(op, cast(tf, tf.constant(gain), type)); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Initializer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Initializer.java index 4beb218783b..2f30bc5e99d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Initializer.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Initializer.java @@ -18,19 +18,16 @@ import 
org.tensorflow.types.TInt64; import org.tensorflow.types.family.TType; -/** - * An interface for Initializers - * - * @param The data Type for initializer operation - */ +/** An interface for Initializers */ public interface Initializer { /** * Generates the operation used to perform the initialization. * * @param dims the shape dimensions - * @param type the type of tensor + * @param type the data type of tensor + * @param The data Type for initializer operation * @return An operand for the initialization. */ - Operand call(Operand dims, Class type); + Operand call(Operand dims, Class type); } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/LeCun.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/LeCun.java index 38e68ef688b..b82f40918c0 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/LeCun.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/LeCun.java @@ -15,7 +15,6 @@ package org.tensorflow.framework.initializers; import org.tensorflow.op.Ops; -import org.tensorflow.types.family.TFloating; /** * LeCun normal initializer. @@ -27,7 +26,7 @@ * stddev = sqrt(1 / fanIn) where fanIn is the number of input units in the * weight tensor. * - *

If the distribution is UNIFORM, itraws samples from a uniform distribution within + *

If the distribution is UNIFORM, it draws samples from a uniform distribution within * [-limit, limit], where limit = Math.sqrt(3 / fanIn) (fanIn is * the number of input units in the weight tensor) * @@ -59,14 +58,14 @@ * *

NOTE: * * - *

For a LeCunNormal equivalent initializer, use {@link VarianceScaling.Distribution#TRUNCATED_NORMAL} for the distribution parameter. * + *

For a LeCunNormal equivalent initializer, use {@link + * VarianceScaling.Distribution#TRUNCATED_NORMAL} for the distribution parameter. * * *

For a LeCunUniform equivalent initializer, use {@link VarianceScaling.Distribution#UNIFORM} * * for the distribution parameter. * * *

* - * @param The TType for the call operation * @see Self-Normalizing * Neural Networks, Klambauer et al., 2017 @@ -74,7 +73,7 @@ * al., 1998 * @see VarianceScaling.Distribution */ -public class LeCun extends VarianceScaling { +public class LeCun extends VarianceScaling { /** * Creates a LeCunNormal Initializer @@ -82,7 +81,7 @@ public class LeCun extends VarianceScaling { * @param tf the TensorFlow Ops * @param distribution The distribution type for the Glorot initializer. * @param seed the seed for random number generation. An initializer created with a given seed - * will always produce the same random tensor for a given shape and dtype. + * will always produce the same random tensor for a given shape and data type. */ public LeCun(Ops tf, Distribution distribution, long seed) { super(tf, 1.0, Mode.FAN_IN, distribution, seed); diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Ones.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Ones.java index b8eb0c418e9..5cf8c7a8033 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Ones.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Ones.java @@ -21,6 +21,8 @@ import org.tensorflow.types.family.TNumber; import org.tensorflow.types.family.TType; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * Initializer that generates tensors initialized to 1. * @@ -32,10 +34,8 @@ * Operand<TFloat32> values = * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); * - * - * @param The TType for the call operation */ -public class Ones extends BaseInitializer { +public class Ones extends BaseInitializer { /** * Creates an Initializer that sets all values to one. @@ -55,12 +55,22 @@ public Ones(Ops tf) { super(tf); } - /** {@inheritDoc} */ + /** + * Generates the operation used to perform the initialization. + * + * @param dims the shape dimensions + * @param type the data type of tensor + * @param The data Type for initializer operation + * @return An operand for the initialization. + * @throws IllegalArgumentException if the data type is not a TNumber or TBool + */ @Override - public Operand call(Operand dims, Class type) { + public Operand call(Operand dims, Class type) { if (!TNumber.class.isAssignableFrom(type) && type != TBool.class) { - throw new IllegalArgumentException("Tensor type must be numeric or boolean: " + type.getSimpleName()); + throw new IllegalArgumentException( + "Tensor type must be numeric or boolean: " + type.getSimpleName()); } - return tf.fill(dims, tf.dtypes.cast(tf.constant(1.0), type)); + + return tf.fill(dims, cast(tf, tf.constant(1), type)); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Orthogonal.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Orthogonal.java index a5b466e118e..14f1049d038 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Orthogonal.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Orthogonal.java @@ -23,6 +23,8 @@ import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TFloating; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * Initializer that generates an orthogonal matrix. 
* @@ -44,10 +46,8 @@ * Operand<TFloat32> values = * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); * - * - * @param The TType for the call operation */ -public class Orthogonal extends BaseInitializer { +public class Orthogonal extends BaseInitializer { public static final double GAIN_DEFAULT = 1.0; @@ -59,7 +59,7 @@ public class Orthogonal extends BaseInitializer { * * @param tf the TensorFlow Ops * @param seed the seed for random number generation. An initializer created with a given seed - * will always produce the same random tensor for a given shape and dtype. + * will always produce the same random tensor for a given shape and data type. */ public Orthogonal(Ops tf, long seed) { this(tf, GAIN_DEFAULT, seed); @@ -71,7 +71,7 @@ public Orthogonal(Ops tf, long seed) { * @param tf the TensorFlow Ops * @param gain the gain to be applied to the Matrix. * @param seed the seed for random number generation. An initializer created with a given seed - * will always produce the same random tensor for a given shape and dtype. + * will always produce the same random tensor for a given shape and data type. */ public Orthogonal(Ops tf, double gain, long seed) { super(tf); @@ -81,7 +81,8 @@ public Orthogonal(Ops tf, double gain, long seed) { /** {@inheritDoc} */ @Override - public Operand call(Operand dims, Class type) { + public Operand call(Operand dims, Class type) { + Shape dimsShape = ShapeUtils.toShape(tf.scope(), dims); if (dimsShape.numDimensions() < 2) { throw new IllegalArgumentException( @@ -94,17 +95,20 @@ public Operand call(Operand dims, Class type) { long numCols = dimsShape.size(i); Shape flatShape = Shape.of(Math.max(numRows, numCols), Math.min(numRows, numCols)); long[] seeds = {seed, 0}; - Operand op = + Operand op = tf.random.statelessRandomNormal(tf.constant(flatShape), tf.constant(seeds), type); + Qr.Options qrOptions = Qr.fullMatrices(false); - Qr qrOp = tf.linalg.qr(op, qrOptions); - Output qo = qrOp.q(); - Output ro = qrOp.r(); - Operand diagOp = - tf.linalg.matrixDiagPart(ro, tf.constant(0), tf.dtypes.cast(tf.constant(0), type)); - Operand qop = tf.math.mul(qo, tf.math.sign(diagOp)); - if (numRows < numCols) qop = tf.linalg.transpose(qop, null); + Qr qrOp = tf.linalg.qr(op, qrOptions); + Output qo = qrOp.q(); + Output ro = qrOp.r(); + Operand diagOp = + tf.linalg.matrixDiagPart(ro, tf.constant(0), cast(tf, tf.constant(0), op.type())); + Operand qop = tf.math.mul(qo, tf.math.sign(diagOp)); + if (numRows < numCols) { + qop = tf.linalg.transpose(qop, null); + } - return tf.math.mul(qop, tf.dtypes.cast(tf.constant(this.gain), type)); + return tf.math.mul(qop, cast(tf, tf.constant(this.gain), op.type())); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomNormal.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomNormal.java index 38ab194a56b..6b90ed3985d 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomNormal.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomNormal.java @@ -19,6 +19,8 @@ import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TFloating; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * Initializer that generates tensors with a normal distribution. 
* @@ -31,10 +33,8 @@ * Operand<TFloat32> values = * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); * - * - * @param The TType for the call operation */ -public class RandomNormal extends BaseInitializer { +public class RandomNormal extends BaseInitializer { public static final double MEAN_DEFAULT = 0.0; public static final double STDDEV_DEFAULT = 1.0; @@ -49,7 +49,7 @@ public class RandomNormal extends BaseInitializer { * * @param tf the TensorFlow Ops * @param seed the seed for random number generation. An initializer created with a given seed - * will always produce the same random tensor for a given shape and dtype. + * will always produce the same random tensor for a given shape and data type. */ public RandomNormal(Ops tf, long seed) { this(tf, MEAN_DEFAULT, STDDEV_DEFAULT, seed); @@ -61,7 +61,7 @@ public RandomNormal(Ops tf, long seed) { * @param tf the TensorFlow Ops * @param mean Mean of the random values to generate. * @param seed the seed for random number generation. An initializer created with a given seed - * will always produce the same random tensor for a given shape and dtype. + * will always produce the same random tensor for a given shape and data type. */ public RandomNormal(Ops tf, double mean, long seed) { this(tf, mean, STDDEV_DEFAULT, seed); @@ -74,10 +74,15 @@ public RandomNormal(Ops tf, double mean, long seed) { * @param mean Mean of the random values to generate. * @param stddev Standard deviation of the random values to generate. * @param seed the seed for random number generation. An initializer created with a given seed - * will always produce the same random tensor for a given shape and dtype. + * will always produce the same random tensor for a given shape and data type. + * @throws IllegalArgumentException if standard deviation is less than 0. */ public RandomNormal(Ops tf, double mean, double stddev, long seed) { super(tf); + if (stddev < 0) { + throw new IllegalArgumentException( + "Standard deviation (stddev) cannot be less than 0, got " + stddev); + } this.mean = mean; this.stddev = stddev; this.seed = seed; @@ -85,10 +90,10 @@ public RandomNormal(Ops tf, double mean, double stddev, long seed) { /** {@inheritDoc} */ @Override - public Operand call(Operand dims, Class type) { + public Operand call(Operand dims, Class type) { long[] seeds = {seed, 0}; - Operand distOp = tf.random.statelessRandomNormal(dims, tf.constant(seeds), type); - Operand op = tf.math.mul(distOp, tf.dtypes.cast(tf.constant(this.stddev), type)); - return tf.math.add(op, tf.dtypes.cast(tf.constant(mean), type)); + Operand distOp = tf.random.statelessRandomNormal(dims, tf.constant(seeds), type); + Operand op = tf.math.mul(distOp, cast(tf, tf.constant(this.stddev), type)); + return tf.math.add(op, cast(tf, tf.constant(mean), distOp.type())); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomUniform.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomUniform.java index 787af15f709..c0ebce135e9 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomUniform.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/RandomUniform.java @@ -21,6 +21,8 @@ import org.tensorflow.types.family.TIntegral; import org.tensorflow.types.family.TNumber; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * Initializer that generates tensors with a uniform distribution. 
* @@ -33,10 +35,8 @@ * Operand<TFloat32> values = * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); * - * - * @param The TType for the call operation */ -public class RandomUniform extends BaseInitializer { +public class RandomUniform extends BaseInitializer { public static final double MINVAL_DEFAULT = -0.05; public static final double MAXVAL_DEFAULT = 0.05; @@ -46,12 +46,12 @@ public class RandomUniform extends BaseInitializer { private final long seed; /** - * Creates a RandomUniform initializer using {@link #MINVAL_DEFAULT} for the minval and - * {@link #MAXVAL_DEFAULT} for the maxval + * Creates a RandomUniform initializer using {@link #MINVAL_DEFAULT} for the minval and {@link + * #MAXVAL_DEFAULT} for the maxval * * @param tf the TensorFlow Ops * @param seed the seed for random number generation. An initializer created with a given seed - * will always produce the same random tensor for a given shape and dtype. + * will always produce the same random tensor for a given shape and data type. */ public RandomUniform(Ops tf, long seed) { this(tf, MINVAL_DEFAULT, MAXVAL_DEFAULT, seed); @@ -64,7 +64,7 @@ public RandomUniform(Ops tf, long seed) { * @param minval Lower bound of the range of random values to generate (inclusive). * @param maxval Upper bound of the range of random values to generate (exclusive). * @param seed the seed for random number generation. An initializer created with a given seed - * will always produce the same random tensor for a given shape and dtype. + * will always produce the same random tensor for a given shape and data type. */ public RandomUniform(Ops tf, double minval, double maxval, long seed) { super(tf); @@ -75,26 +75,27 @@ public RandomUniform(Ops tf, double minval, double maxval, long seed) { /** {@inheritDoc} */ @Override - public Operand call(Operand dims, Class type) { - Operand distOp; + public Operand call(Operand dims, Class type) { + Operand distOp; if (TIntegral.class.isAssignableFrom(type)) { RandomUniformInt.Options options = RandomUniformInt.seed(this.seed); distOp = tf.random.randomUniformInt( dims, - tf.dtypes.cast(tf.constant(this.minval), type), - tf.dtypes.cast(tf.constant(this.maxval), type), + cast(tf, tf.constant(this.minval), type), + cast(tf, tf.constant(this.maxval), type), options); } else { long[] seeds = {seed, 0}; distOp = tf.random.statelessRandomUniform(dims, tf.constant(seeds), type); if (this.minval == 0) { if (this.maxval != 1.0) { - distOp = tf.math.mul(distOp, tf.dtypes.cast(tf.constant(this.maxval), type)); + distOp = tf.math.mul(distOp, cast(tf, tf.constant(this.maxval), distOp.type())); } } else { - distOp = tf.math.mul(distOp, tf.dtypes.cast(tf.constant(this.maxval - this.minval), type)); - distOp = tf.math.add(distOp, tf.dtypes.cast(tf.constant(this.minval), type)); + distOp = + tf.math.mul(distOp, cast(tf, tf.constant(this.maxval - this.minval), distOp.type())); + distOp = tf.math.add(distOp, cast(tf, tf.constant(this.minval), distOp.type())); } } return distOp; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/TruncatedNormal.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/TruncatedNormal.java index d3cfec26338..eaf94663993 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/TruncatedNormal.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/TruncatedNormal.java @@ -19,6 +19,8 @@ import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TFloating; +import 
static org.tensorflow.framework.utils.CastHelper.cast; + /** * Initializer that generates a truncated normal distribution. * @@ -31,10 +33,8 @@ * Operand<TFloat32> values = * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); * - * - * @param The TType for the call operation */ -public class TruncatedNormal extends BaseInitializer { +public class TruncatedNormal extends BaseInitializer { public static final double MEAN_DEFAULT = 0.0; public static final double STDDEV_DEFAULT = 0.05; @@ -49,7 +49,7 @@ public class TruncatedNormal extends BaseInitializer { * * @param tf the TensorFlow Ops * @param seed the seed for random number generation. An initializer created with a given seed - * will always produce the same random tensor for a given shape and dtype. + * will always produce the same random tensor for a given shape and data type. */ public TruncatedNormal(Ops tf, long seed) { this(tf, MEAN_DEFAULT, STDDEV_DEFAULT, seed); @@ -62,7 +62,7 @@ public TruncatedNormal(Ops tf, long seed) { * @param mean Mean of the random values to generate. * @param stddev Standard deviation of the random values to generate. * @param seed the seed for random number generation. An initializer created with a given seed - * will always produce the same random tensor for a given shape and dtype. + * will always produce the same random tensor for a given shape and data type. */ public TruncatedNormal(Ops tf, double mean, double stddev, long seed) { super(tf); @@ -73,11 +73,11 @@ public TruncatedNormal(Ops tf, double mean, double stddev, long seed) { /** {@inheritDoc} */ @Override - public Operand call(Operand dims, Class type) { - long[] seeds = {seed,0}; - Operand distOp = tf.random.statelessTruncatedNormal(dims, tf.constant(seeds), type); + public Operand call(Operand dims, Class type) { + long[] seeds = {seed, 0}; + Operand distOp = tf.random.statelessTruncatedNormal(dims, tf.constant(seeds), type); return tf.math.add( - tf.math.mul(distOp, tf.dtypes.cast(tf.constant(stddev), type)), - tf.dtypes.cast(tf.constant(mean), type)); + tf.math.mul(distOp, cast(tf, tf.constant(stddev), distOp.type())), + cast(tf, tf.constant(mean), distOp.type())); } } diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/VarianceScaling.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/VarianceScaling.java index 5d951450505..299d719c7c3 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/VarianceScaling.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/VarianceScaling.java @@ -21,11 +21,13 @@ import org.tensorflow.types.TInt64; import org.tensorflow.types.family.TFloating; +import static org.tensorflow.framework.utils.CastHelper.cast; + /** * Initializer capable of adapting its scale to the shape of weights tensors. * - *

With distribution=TRUNCATED_NORMAL or NORMAL, samples are drawn from - * a truncated/untruncated normal distribution with a mean of zero and a standard deviation (after + *

With distribution=TRUNCATED_NORMAL or NORMAL, samples are drawn from a + * truncated/untruncated normal distribution with a mean of zero and a standard deviation (after * truncation, if used) stddev = Math.sqrt(scale / n), where n is: * *

    @@ -49,11 +51,10 @@ * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); * * - * @param The TType for the call operation * @see VarianceScaling.Mode * @see VarianceScaling.Distribution */ -public class VarianceScaling extends BaseInitializer { +public class VarianceScaling extends BaseInitializer { public static final double SCALE_DEFAULT = 1.0; public static final Mode MODE_DEFAULT = Mode.FAN_IN; @@ -64,7 +65,6 @@ public class VarianceScaling extends BaseInitializer { private final Distribution distribution; private final long seed; - /** * Creates a VarianceScaling Initializer * @@ -97,7 +97,7 @@ public VarianceScaling(Ops tf, double scale, Mode mode, Distribution distributio /** {@inheritDoc} */ @Override - public Operand call(Operand dims, Class type) { + public Operand call(Operand dims, Class type) { Shape shape = ShapeUtils.toShape(this.tf.scope(), dims); double lscale = this.scale; double[] fans /* fanIn, fanOut */ = computeFans(shape); @@ -112,25 +112,25 @@ public Operand call(Operand dims, Class type) { lscale /= Math.max(1., (fans[0] + fans[1]) / 2.); break; } - Operand distOp; - Operand mulOp = null; + Operand distOp; + Operand mulOp = null; double stddev; long[] seeds = {seed, 0}; switch (distribution) { case TRUNCATED_NORMAL: distOp = tf.random.statelessTruncatedNormal(dims, tf.constant(seeds), type); stddev = Math.sqrt(lscale) / .87962566103423978; - mulOp = tf.math.mul(distOp, tf.dtypes.cast(tf.constant(stddev), type)); + mulOp = tf.math.mul(distOp, cast(tf, tf.constant(stddev), type)); break; case NORMAL: distOp = tf.random.statelessRandomNormal(dims, tf.constant(seeds), type); stddev = Math.sqrt(lscale); - mulOp = tf.math.mul(distOp, tf.dtypes.cast(tf.constant(stddev), type)); + mulOp = tf.math.mul(distOp, cast(tf, tf.constant(stddev), type)); break; case UNIFORM: distOp = tf.random.statelessRandomUniform(dims, tf.constant(seeds), type); stddev = Math.sqrt(3.0 * lscale); - mulOp = tf.math.mul(distOp, tf.dtypes.cast(tf.constant(stddev), type)); + mulOp = tf.math.mul(distOp, cast(tf, tf.constant(stddev), type)); break; } return mulOp; diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Zeros.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Zeros.java index 4298493ac44..b6487dc10cd 100644 --- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Zeros.java +++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Zeros.java @@ -16,7 +16,9 @@ import org.tensorflow.Operand; import org.tensorflow.op.Ops; +import org.tensorflow.types.TBool; import org.tensorflow.types.TInt64; +import org.tensorflow.types.family.TNumber; import org.tensorflow.types.family.TType; /** @@ -30,10 +32,8 @@ * Operand<TFloat32> values = * initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class); * - * - * @param The TType for the call operation */ -public class Zeros extends BaseInitializer { +public class Zeros extends BaseInitializer { /** * Creates an Initializer that sets all values to one. @@ -44,8 +44,21 @@ public Zeros(Ops tf) { super(tf); } + /** + * Generates the operation used to perform the initialization. + * + * @param dims the shape dimensions + * @param type the data type of tensor + * @param The data Type for initializer operation + * @return An operand for the initialization. 
+ * @throws IllegalArgumentException if the data type is not a TNumber or TBool + */ @Override - public Operand call(Operand dims, Class dtype) { - return tf.zeros(dims, dtype); + public Operand call(Operand dims, Class type) { + if (!TNumber.class.isAssignableFrom(type) && type != TBool.class) { + throw new IllegalArgumentException( + "Tensor type must be numeric or boolean: " + type.getSimpleName()); + } + return tf.zeros(dims, type); } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ELUTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ELUTest.java index 914b94dfada..22604a02c08 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ELUTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ELUTest.java @@ -21,8 +21,6 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import static org.junit.jupiter.api.Assertions.assertThrows; - /** @author Jim Clarke */ public class ELUTest { @@ -42,8 +40,6 @@ public void setUp() {} @AfterEach public void tearDown() {} - - /** Test of ELU call method */ @Test public void testCallFloat() { @@ -52,7 +48,7 @@ public void testCallFloat() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - ELU instance = new ELU<>(tf); + ELU instance = new ELU(tf); Operand result = instance.call(tf.constant(input)); session.evaluate(expected, result); } @@ -66,7 +62,7 @@ public void testCallDouble() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - ELU instance = new ELU<>(tf); + ELU instance = new ELU(tf); Operand result = instance.call(tf.constant(input)); session.evaluate(expected, result); } @@ -80,7 +76,7 @@ public void testAlpha() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - ELU instance = new ELU<>(tf, 2.0f); + ELU instance = new ELU(tf, 2.0f); Operand result = instance.call(tf.constant(input)); session.evaluate(expected, result); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ExponentialTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ExponentialTest.java index 1157c582168..e8758148c21 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ExponentialTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ExponentialTest.java @@ -21,8 +21,6 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import static org.junit.jupiter.api.Assertions.assertThrows; - /** @author Jim Clarke */ public class ExponentialTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; @@ -41,8 +39,6 @@ public void setUp() {} @AfterEach public void tearDown() {} - - /** Test of Exponential call method. 
*/ @Test public void testCallFloat() { @@ -60,7 +56,7 @@ public void testCallFloat() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Exponential instance = new Exponential<>(tf); + Exponential instance = new Exponential(tf); Operand result = instance.call(tf.constant(input)); session.evaluate(expected, result); } @@ -78,7 +74,7 @@ public void testCallDouble() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Exponential instance = new Exponential<>(tf); + Exponential instance = new Exponential(tf); Operand result = instance.call(tf.constant(input)); session.evaluate(expected, result); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/GeLUTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/GeLUTest.java new file mode 100644 index 00000000000..e2c97e18bbd --- /dev/null +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/GeLUTest.java @@ -0,0 +1,93 @@ +package org.tensorflow.framework.activations; + +import org.junit.jupiter.api.Test; +import org.tensorflow.Operand; +import org.tensorflow.framework.utils.TestSession; +import org.tensorflow.op.Ops; +import org.tensorflow.types.TFloat32; +import org.tensorflow.types.TFloat64; + +class GeLUTest { + + private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; + + /** Test of GeLU call method */ + @Test + public void testCallFloat() { + float[][] input = { + {0.22805803f, 0.60407318f, 0.91519962f, 0.35643331f, 0.28702669f}, + {0.11558246f, 0.57658853f, 0.47569648f, 0.02271072f, 0.24709974f} + }; + float[][] expected = { + {0.13459972f, 0.43922312f, 0.75042395f, 0.22784713f, 0.17593417f}, + {0.06310898f, 0.41392788f, 0.32483157f, 0.01156111f, 0.14766297f} + }; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + GeLU instance = new GeLU(tf); + Operand result = instance.call(tf.constant(input)); + session.evaluate(tf.constant(expected), result); + } + } + + /** Test of GeLU call method */ + @Test + public void testCallDouble() { + double[][] input = { + {0.22805803, 0.60407318, 0.91519962, 0.35643331, 0.28702669}, + {0.11558246, 0.57658853, 0.47569648, 0.02271072, 0.24709974} + }; + double[][] expected = { + {0.13459972, 0.43922312, 0.75042395, 0.22784713, 0.17593417}, + {0.06310898, 0.41392788, 0.32483157, 0.01156111, 0.14766297} + }; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + GeLU instance = new GeLU(tf); + Operand result = instance.call(tf.constant(input)); + session.evaluate(tf.constant(expected), result); + } + } + + /** Test of GeLU call method */ + @Test + public void testCallFloatApproximate() { + float[][] input = { + {0.22805803f, 0.60407318f, 0.91519962f, 0.35643331f, 0.28702669f}, + {0.11558246f, 0.57658853f, 0.47569648f, 0.02271072f, 0.24709974f} + }; + float[][] expected = { + {0.13459886f, 0.43918941f, 0.75030122f, 0.22784227f, 0.17593207f}, + {0.06310892f, 0.41389921f, 0.32481722f, 0.01156111f, 0.14766179f} + }; + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + GeLU instance = new GeLU(tf, true); + Operand result = instance.call(tf.constant(input)); + 
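// Editor's sketch (not part of the patch): the two GeLU constructions exercised by the new
// GeLUTest. Only the test appears in this diff, so the constructor forms below are taken
// from it; the boolean flag presumably selects the tanh approximation, which is why the
// "approximate" expected arrays differ slightly from the exact ones.
Ops tf = Ops.create();
GeLU exact = new GeLU(tf);
GeLU approximate = new GeLU(tf, true);
Operand<TFloat32> y = exact.call(tf.constant(new float[] {0.5f, 1.0f}));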
session.evaluate(tf.constant(expected), result); + } + } + + /** Test of GeLU call method */ + @Test + public void testCallDoubleApproximate() { + double[][] input = { + {0.22805803, 0.60407318, 0.91519962, 0.35643331, 0.28702669}, + {0.11558246, 0.57658853, 0.47569648, 0.02271072, 0.24709974} + }; + double[][] expected = { + {0.13459886, 0.43918941, 0.75030122, 0.22784227, 0.17593207}, + {0.06310892, 0.41389921, 0.32481722, 0.01156111, 0.14766179} + }; + // for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(TestSession.Mode.GRAPH)) { + Ops tf = session.getTF(); + GeLU instance = new GeLU(tf, true); + Operand result = instance.call(tf.constant(input)); + session.evaluate(tf.constant(expected), result); + } + } +} diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/HardSigmoidTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/HardSigmoidTest.java index 35f57c47f66..02b8e565d66 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/HardSigmoidTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/HardSigmoidTest.java @@ -20,8 +20,7 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; - -import static org.junit.jupiter.api.Assertions.assertThrows; +import org.tensorflow.types.TInt32; /** @author Jim Clarke */ public class HardSigmoidTest { @@ -41,8 +40,6 @@ public void setUp() {} @AfterEach public void tearDown() {} - - /** Test of HardSigmoid call method. */ @Test public void testCallFloat() { @@ -51,7 +48,7 @@ public void testCallFloat() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - HardSigmoid instance = new HardSigmoid<>(tf); + HardSigmoid instance = new HardSigmoid(tf); Operand result = instance.call(tf.constant(input)); session.evaluate(expected, result); } @@ -65,9 +62,24 @@ public void testCallDouble() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - HardSigmoid instance = new HardSigmoid<>(tf); + HardSigmoid instance = new HardSigmoid(tf); Operand result = instance.call(tf.constant(input)); session.evaluate(expected, result); } } + + /** Test of HardSigmoid call method. 
*/ + @Test + public void testCallInt() { + int[] input = {-3, -1, 0, 1, 3}; + + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + HardSigmoid instance = new HardSigmoid(tf); + Operand result = instance.call(tf.constant(input)); + int[] expected = {0, 0, 0, 0, 0}; + session.evaluate(expected, result); + } + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/LinearTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/LinearTest.java index 7974035c680..4e82420c922 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/LinearTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/LinearTest.java @@ -48,7 +48,7 @@ public void testCallInt() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Linear instance = new Linear<>(tf); + Linear instance = new Linear(tf); Operand result = instance.call(tf.constant(input)); session.evaluate(expected, result); } @@ -62,7 +62,7 @@ public void testCallFloat() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Linear instance = new Linear<>(tf); + Linear instance = new Linear(tf); Operand result = instance.call(tf.constant(input)); session.evaluate(expected, result); } @@ -76,7 +76,7 @@ public void testCallDouble() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Linear instance = new Linear<>(tf); + Linear instance = new Linear(tf); Operand result = instance.call(tf.constant(input)); session.evaluate(expected, result); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ReLUTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ReLUTest.java index a0aa2c4b453..3eb32acff86 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ReLUTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/ReLUTest.java @@ -46,7 +46,7 @@ public void testCallFloat() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - ReLU instance = new ReLU<>(tf); + ReLU instance = new ReLU(tf); Operand result = instance.call(tf.constant(input)); session.evaluate(tf.constant(expected), result); } @@ -60,7 +60,7 @@ public void testCallInt() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - ReLU instance = new ReLU<>(tf); + ReLU instance = new ReLU(tf); Operand result = instance.call(tf.constant(input)); session.evaluate(tf.constant(expected), result); } @@ -74,7 +74,7 @@ public void testCallLong() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - ReLU instance = new ReLU<>(tf); + ReLU instance = new ReLU(tf); Operand result = instance.call(tf.constant(input)); session.evaluate(tf.constant(expected), result); } @@ -88,7 +88,7 @@ public void testCallFloat16() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - ReLU instance = new ReLU<>(tf); + ReLU instance = new 
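// Editor's note (not part of the patch): the new testCallInt above expects {0, 0, 0, 0, 0}
// for int input {-3, -1, 0, 1, 3}. Assuming the usual hard-sigmoid form
// max(0, min(1, 0.2 * x + 0.5)) evaluated in the input's own dtype, integer arithmetic
// truncates every intermediate value to 0, which is what the expected array encodes.
Ops tf = Ops.create();
HardSigmoid hardSigmoid = new HardSigmoid(tf);
Operand<TInt32> out = hardSigmoid.call(tf.constant(new int[] {-3, -1, 0, 1, 3})); // all zeros per the test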
ReLU(tf); Operand result = instance.call(tf.dtypes.cast(tf.constant(input), TFloat16.class)); session.evaluate(tf.dtypes.cast(tf.constant(expected), TFloat16.class), result); @@ -103,7 +103,7 @@ public void testCallDouble() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - ReLU instance = new ReLU<>(tf); + ReLU instance = new ReLU(tf); Operand result = instance.call(tf.constant(input)); session.evaluate(tf.constant(expected), result); } @@ -112,11 +112,11 @@ public void testCallDouble() { @Test public void testAlpha() { double[] input = {-10., -5., 0.0, 5., 10.}; - double[] expected = {-5. , -2.5, 0., 5., 10.}; + double[] expected = {-5., -2.5, 0., 5., 10.}; for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - ReLU instance = new ReLU<>(tf, 0.5f, ReLU.MAX_VALUE_DEFAULT, ReLU.THRESHOLD_DEFAULT); + ReLU instance = new ReLU(tf, 0.5f, ReLU.MAX_VALUE_DEFAULT, ReLU.THRESHOLD_DEFAULT); Operand result = instance.call(tf.constant(input)); session.evaluate(tf.constant(expected), result); } @@ -129,7 +129,7 @@ public void testMaxValue() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - ReLU instance = new ReLU<>(tf, ReLU.ALPHA_DEFAULT, 5, ReLU.THRESHOLD_DEFAULT); + ReLU instance = new ReLU(tf, ReLU.ALPHA_DEFAULT, 5, ReLU.THRESHOLD_DEFAULT); Operand result = instance.call(tf.constant(input)); session.evaluate(tf.constant(expected), result); } @@ -138,11 +138,11 @@ public void testMaxValue() { @Test public void testThreshold() { double[] input = {-10., -5., 0.0, 5., 10.}; - double[] expected = {-0., -0., 0., 0., 10.}; + double[] expected = {-0., -0., 0., 0., 10.}; for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - ReLU instance = new ReLU<>(tf, ReLU.ALPHA_DEFAULT, ReLU.MAX_VALUE_DEFAULT, 5.0f); + ReLU instance = new ReLU(tf, ReLU.ALPHA_DEFAULT, ReLU.MAX_VALUE_DEFAULT, 5.0f); Operand result = instance.call(tf.constant(input)); session.evaluate(tf.constant(expected), result); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SELUTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SELUTest.java index 8bad6f1f066..65d1a2f0135 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SELUTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SELUTest.java @@ -21,8 +21,6 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import static org.junit.jupiter.api.Assertions.assertThrows; - /** @author Jim Clarke */ public class SELUTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; @@ -41,8 +39,6 @@ public void setUp() {} @AfterEach public void tearDown() {} - - /** Test of SELU call method */ @Test public void testCallFloat() { @@ -53,7 +49,7 @@ public void testCallFloat() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - SELU instance = new SELU<>(tf); + SELU instance = new SELU(tf); Operand result = instance.call(tf.constant(input)); session.evaluate(expected, result); } @@ -71,7 +67,7 @@ public void testCallDouble() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = 
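// Editor's sketch (not part of the patch): the parameterized ReLU constructions used in
// testAlpha / testMaxValue / testThreshold above, now without the class type parameter.
Ops tf = Ops.create();
ReLU leaky = new ReLU(tf, 0.5f, ReLU.MAX_VALUE_DEFAULT, ReLU.THRESHOLD_DEFAULT);  // negative slope 0.5
ReLU clipped = new ReLU(tf, ReLU.ALPHA_DEFAULT, 5, ReLU.THRESHOLD_DEFAULT);       // values capped at 5
ReLU gated = new ReLU(tf, ReLU.ALPHA_DEFAULT, ReLU.MAX_VALUE_DEFAULT, 5.0f);      // zero below threshold 5
Operand<TFloat64> y = leaky.call(tf.constant(new double[] {-10., -5., 0., 5., 10.})); // {-5., -2.5, 0., 5., 10.}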
TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - SELU instance = new SELU<>(tf); + SELU instance = new SELU(tf); Operand result = instance.call(tf.constant(input)); session.evaluate(expected, result); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SigmoidTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SigmoidTest.java index 9dca622c3ec..6177d142794 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SigmoidTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SigmoidTest.java @@ -21,8 +21,6 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import static org.junit.jupiter.api.Assertions.assertThrows; - /** @author Jim Clarke */ public class SigmoidTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; @@ -41,7 +39,6 @@ public void setUp() {} @AfterEach public void tearDown() {} - /** Test of Sigmoid call method */ @Test public void testCallFloat() { @@ -59,7 +56,7 @@ public void testCallFloat() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Sigmoid instance = new Sigmoid<>(tf); + Sigmoid instance = new Sigmoid(tf); Operand result = instance.call(tf.constant(input)); session.evaluate(expected, result); } @@ -77,7 +74,7 @@ public void testCallDouble() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Sigmoid instance = new Sigmoid<>(tf); + Sigmoid instance = new Sigmoid(tf); Operand result = instance.call(tf.constant(input)); session.evaluate(expected, result); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftmaxTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftmaxTest.java index 05ec3a4f716..749b8f34602 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftmaxTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftmaxTest.java @@ -21,8 +21,6 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import static org.junit.jupiter.api.Assertions.assertThrows; - /** @author Jim Clarke */ public class SoftmaxTest { @@ -42,7 +40,6 @@ public void setUp() {} @AfterEach public void tearDown() {} - /** Test of Softmax method, of class Activations. 
*/ @Test public void testSoftmaxOpsOperandFloat() { @@ -54,7 +51,7 @@ public void testSoftmaxOpsOperandFloat() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Softmax instance = new Softmax<>(tf); + Softmax instance = new Softmax(tf); Operand result = instance.call(tf.constant(input)); session.evaluate(tf.constant(expected), result); } @@ -71,7 +68,7 @@ public void testSoftmaxOpsOperandDouble() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Softmax instance = new Softmax<>(tf); + Softmax instance = new Softmax(tf); Operand result = instance.call(tf.constant(input)); session.evaluate(tf.constant(expected), result); } @@ -88,7 +85,7 @@ public void testSoftmaxOpsOperandDoubleNegative() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Softmax instance = new Softmax<>(tf); + Softmax instance = new Softmax(tf); Operand result = instance.call(tf.constant(input)); session.evaluate(tf.constant(expected), result); } @@ -99,13 +96,13 @@ public void testSoftmaxOpsOperandDoubleNegative() { public void testSoftmax1D() { double[] input = {1, -2, 3, -4, -5, 6, 7, 8}; double[] expected = { - 6.0352829e-04, 3.0047902e-05, 4.4595040e-03, 4.0665414e-06, - 1.4959969e-06, 8.9571528e-02, 2.4348068e-01, 6.6184908e-01 + 6.0352829e-04, 3.0047902e-05, 4.4595040e-03, 4.0665414e-06, + 1.4959969e-06, 8.9571528e-02, 2.4348068e-01, 6.6184908e-01 }; for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Softmax instance = new Softmax<>(tf); + Softmax instance = new Softmax(tf); Operand result = instance.call(tf.constant(input)); session.evaluate(tf.constant(expected), result); } @@ -116,13 +113,13 @@ public void testSoftmax1D() { public void testSoftmax3D() { double[][][] input = {{{1, -2}, {3, -4}}, {{-5, 6}, {-7, 8}}}; double[][][] expected = { - {{9.5257413e-01, 4.7425874e-02}, {9.9908900e-01, 9.1105123e-04}}, - {{1.6701422e-05, 9.9998331e-01}, {3.0590220e-07, 9.9999964e-01}} + {{9.5257413e-01, 4.7425874e-02}, {9.9908900e-01, 9.1105123e-04}}, + {{1.6701422e-05, 9.9998331e-01}, {3.0590220e-07, 9.9999964e-01}} }; for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Softmax instance = new Softmax<>(tf); + Softmax instance = new Softmax(tf); Operand result = instance.call(tf.constant(input)); session.evaluate(tf.constant(expected), result); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftplusTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftplusTest.java index a17f2650d62..58a7ddcd3a8 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftplusTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftplusTest.java @@ -50,7 +50,7 @@ public void testCallFloat() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Softplus instance = new Softplus<>(tf); + Softplus instance = new Softplus(tf); Operand result = instance.call(tf.constant(input)); session.evaluate(expected, result); } @@ -68,7 +68,7 @@ public void testCallDouble() { for (TestSession.Mode tfMode : tfModes) try 
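// Editor's sketch (not part of the patch): Softmax as exercised above, now without a type
// parameter. The expected arrays in testSoftmax1D / testSoftmax3D are consistent with
// normalization along the last axis.
Ops tf = Ops.create();
Softmax softmax = new Softmax(tf);
Operand<TFloat64> y1 = softmax.call(tf.constant(new double[] {1, -2, 3, -4, -5, 6, 7, 8}));
Operand<TFloat64> y3 = softmax.call(tf.constant(new double[][][] {{{1, -2}, {3, -4}}, {{-5, 6}, {-7, 8}}}));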
(TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Softplus instance = new Softplus<>(tf); + Softplus instance = new Softplus(tf); Operand result = instance.call(tf.constant(input)); session.evaluate(expected, result); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftsignTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftsignTest.java index 43591ab4761..bfff18fa9f7 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftsignTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SoftsignTest.java @@ -48,7 +48,7 @@ public void testCallFloat() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Softsign instance = new Softsign<>(tf); + Softsign instance = new Softsign(tf); Operand result = instance.call(tf.constant(input)); session.evaluate(expected, result); } @@ -71,7 +71,7 @@ public void testCallDouble() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Softsign instance = new Softsign<>(tf); + Softsign instance = new Softsign(tf); Operand result = instance.call(tf.constant(input)); session.evaluate(expected, result); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SwishTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SwishTest.java index 7576789320b..d9add75d8e0 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SwishTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/SwishTest.java @@ -21,8 +21,6 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import static org.junit.jupiter.api.Assertions.assertThrows; - /** @author Jim Clarke */ public class SwishTest { private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; @@ -41,8 +39,6 @@ public void setUp() {} @AfterEach public void tearDown() {} - - /** Test of Swish call method */ @Test public void testCallFloat() { @@ -60,7 +56,7 @@ public void testCallFloat() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Swish instance = new Swish<>(tf); + Swish instance = new Swish(tf); Operand result = instance.call(tf.constant(input)); session.evaluate(expected, result); } @@ -83,7 +79,7 @@ public void testCallDouble() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Swish instance = new Swish<>(tf); + Swish instance = new Swish(tf); Operand result = instance.call(tf.constant(input)); session.evaluate(expected, result); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/TanhTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/TanhTest.java index 5162e141c44..98104d552c7 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/TanhTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/activations/TanhTest.java @@ -52,7 +52,7 @@ public void testCallFloat() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Tanh instance 
= new Tanh<>(tf); + Tanh instance = new Tanh(tf); Operand result = instance.call(tf.constant(input)); session.evaluate(expected, result); } @@ -71,7 +71,7 @@ public void testCallDouble() { for (TestSession.Mode tfMode : tfModes) try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); - Tanh instance = new Tanh<>(tf); + Tanh instance = new Tanh(tf); Operand result = instance.call(tf.constant(input)); session.evaluate(expected, result); } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ConstantTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ConstantTest.java index 4e81e0620e6..a66b013a0c2 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ConstantTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ConstantTest.java @@ -22,7 +22,6 @@ import org.tensorflow.types.*; import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.fail; /** Test the Constant initializer */ public class ConstantTest { @@ -51,7 +50,7 @@ public void testCallUInt() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Constant instance = new Constant<>(tf, 0xf); + Constant instance = new Constant(tf, tf.constant(0xf)); Operand operand = instance.call(tf.constant(shape), TUint8.class); session.evaluate(expected, operand); } @@ -67,7 +66,7 @@ public void testCallInt() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Constant instance = new Constant<>(tf, 0xf); + Constant instance = new Constant(tf, tf.constant(0xf)); Operand operand = instance.call(tf.constant(shape), TInt32.class); session.evaluate(expected, operand); } @@ -83,7 +82,7 @@ public void testCallLong() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Constant instance = new Constant<>(tf, 0xffL); + Constant instance = new Constant(tf, tf.constant(0xffL)); Operand operand = instance.call(tf.constant(shape), TInt64.class); session.evaluate(expected, operand); } @@ -97,7 +96,7 @@ public void testCallFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Constant instance = new Constant<>(tf, 12.F); + Constant instance = new Constant(tf, tf.constant(12.f)); Operand operand = instance.call(tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } @@ -112,7 +111,7 @@ public void testCallDouble() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Constant instance = new Constant<>(tf, 11.); + Constant instance = new Constant(tf, tf.constant(11.)); Operand operand = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } @@ -121,6 +120,52 @@ public void testCallDouble() { /** Test of call method, of class Constant. */ @Test public void testCallString() { + for (TestSession.Mode tfMode : tfModes) + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Shape shape = Shape.of(2, 2); + String[][] expected = { + {"Java Test", "Java Test"}, + {"Java Test", "Java Test"}, + }; + // There is no tf.constant(String[][]). 
+ Operand expectedOp = + org.tensorflow.op.core.Constant.tensorOf(tf.scope(), expected); + + Constant instance = new Constant(tf, tf.constant("Java Test")); + Operand result = instance.call(tf.constant(shape), TString.class); + session.evaluate(expectedOp, result); + } + } + + /** Test of call method, of class Constant. */ + @Test + public void testCallStringInvalidDataType() { + for (TestSession.Mode tfMode : tfModes) + assertThrows( + org.tensorflow.exceptions.TFUnimplementedException.class, + () -> { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + Shape shape = Shape.of(2, 2); + String[][] expected = { + {"Java Test", "Java Test"}, + {"Java Test", "Java Test"}, + }; + // There is no tf.constant(String[][]). + Operand expectedOp = + org.tensorflow.op.core.Constant.tensorOf(tf.scope(), expected); + + Constant instance = new Constant(tf, tf.constant("Java Test")); + Operand result = instance.call(tf.constant(shape), TInt32.class); + session.run(result); // this will cause the exception to be thrown. + } + }); + } + + /** Test of call method, of class Constant. */ + @Test + public void testCallNonScalar() { for (TestSession.Mode tfMode : tfModes) assertThrows( java.lang.IllegalArgumentException.class, @@ -129,9 +174,8 @@ public void testCallString() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Constant instance = new Constant<>(tf, 22); - instance.call(tf.constant(shape), TString.class); - fail("IllegalArgumentException should have been thrown for TString"); + Constant instance = new Constant(tf, tf.constant(new int[] {1, 2})); + instance.call(tf.constant(shape), TInt32.class); } }); } @@ -145,7 +189,7 @@ public void testCallBool() { Shape shape = Shape.of(2, 2); Boolean[] expected = {true, true, true, true}; - Constant instance = new Constant<>(tf, true); + Constant instance = new Constant(tf, tf.constant(true)); Operand operand = instance.call(tf.constant(shape), TBool.class); session.evaluate(expected, operand); } @@ -158,7 +202,7 @@ public void testReproducible() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Constant instance = new Constant<>(tf, 11.); + Constant instance = new Constant(tf, tf.constant(11.)); Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/GlorotTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/GlorotTest.java index e9769806928..c46206ad358 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/GlorotTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/GlorotTest.java @@ -51,7 +51,7 @@ public void testCallNormalFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Glorot instance = new Glorot<>(tf, Distribution.TRUNCATED_NORMAL, SEED); + Glorot instance = new Glorot(tf, Distribution.TRUNCATED_NORMAL, SEED); Operand operand = instance.call(tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); @@ -68,7 +68,7 @@ public void testCallNormalDouble() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Glorot instance = new Glorot<>(tf, Distribution.TRUNCATED_NORMAL, SEED); + Glorot instance = new Glorot(tf, Distribution.TRUNCATED_NORMAL, SEED); Operand operand = 
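// Editor's sketch (not part of the patch): the Constant initializer now takes an Operand
// value rather than a primitive, which is what lets testCallString above work; a non-scalar
// value is rejected with IllegalArgumentException (testCallNonScalar).
Ops tf = Ops.create();
Constant fill = new Constant(tf, tf.constant(12.f));
Operand<TFloat32> floats = fill.call(tf.constant(Shape.of(2, 2)), TFloat32.class);  // all 12.0f
Constant text = new Constant(tf, tf.constant("Java Test"));
Operand<TString> strings = text.call(tf.constant(Shape.of(2, 2)), TString.class);   // all "Java Test"
// new Constant(tf, tf.constant(new int[] {1, 2})).call(...)  -> IllegalArgumentException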
instance.call(tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } @@ -82,7 +82,7 @@ public void testCallUniformFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Glorot instance = new Glorot<>(tf, Distribution.UNIFORM, SEED); + Glorot instance = new Glorot(tf, Distribution.UNIFORM, SEED); Operand operand = instance.call(tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } @@ -97,7 +97,7 @@ public void testCallUniformDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Glorot instance = new Glorot<>(tf, Distribution.UNIFORM, SEED); + Glorot instance = new Glorot(tf, Distribution.UNIFORM, SEED); Operand operand = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } @@ -109,7 +109,7 @@ public void testCallNormalReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Glorot instance = new Glorot<>(tf, Distribution.TRUNCATED_NORMAL, SEED); + Glorot instance = new Glorot(tf, Distribution.TRUNCATED_NORMAL, SEED); Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); @@ -122,7 +122,7 @@ public void testCallUniformReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Glorot instance = new Glorot<>(tf, Distribution.UNIFORM, SEED); + Glorot instance = new Glorot(tf, Distribution.UNIFORM, SEED); Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); @@ -135,8 +135,7 @@ public void testCallNORMALReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Glorot instance = - new Glorot<>(tf, Distribution.NORMAL, SEED); + Glorot instance = new Glorot(tf, Distribution.NORMAL, SEED); Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/HeTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/HeTest.java index 8953fa3005e..2c9ef3fc56d 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/HeTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/HeTest.java @@ -27,7 +27,6 @@ public class HeTest { private static final long SEED = 1000L; private final TestSession.Mode[] tfModes = {TestSession.Mode.EAGER, TestSession.Mode.GRAPH}; - int counter; public HeTest() {} @@ -51,7 +50,7 @@ public void testCallNormalFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - He instance = new He<>(tf, Distribution.TRUNCATED_NORMAL, SEED); + He instance = new He(tf, Distribution.TRUNCATED_NORMAL, SEED); Operand operand = instance.call(tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } @@ -66,7 +65,7 @@ public void testCallNormalDouble() { try (TestSession session = 
TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - He instance = new He<>(tf, Distribution.TRUNCATED_NORMAL, SEED); + He instance = new He(tf, Distribution.TRUNCATED_NORMAL, SEED); Operand operand = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } @@ -80,7 +79,7 @@ public void testCallFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - He instance = new He<>(tf, Distribution.UNIFORM, SEED); + He instance = new He(tf, Distribution.UNIFORM, SEED); Operand operand = instance.call(tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } @@ -95,7 +94,7 @@ public void testCallDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - He instance = new He<>(tf, Distribution.UNIFORM, SEED); + He instance = new He(tf, Distribution.UNIFORM, SEED); Operand operand = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } @@ -107,7 +106,7 @@ public void testCallNormalReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - He instance = new He<>(tf, Distribution.TRUNCATED_NORMAL, SEED); + He instance = new He(tf, Distribution.TRUNCATED_NORMAL, SEED); Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); @@ -120,7 +119,7 @@ public void testCallUniformReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - He instance = new He<>(tf, Distribution.UNIFORM, SEED); + He instance = new He(tf, Distribution.UNIFORM, SEED); Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); @@ -133,7 +132,7 @@ public void testCallNORMALReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - He instance = new He<>(tf, Distribution.NORMAL, SEED); + He instance = new He(tf, Distribution.NORMAL, SEED); Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/IdentityTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/IdentityTest.java index 6eee5473937..ef72422474d 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/IdentityTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/IdentityTest.java @@ -21,10 +21,6 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import org.tensorflow.types.TInt32; - -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.fail; /** Test the Identity initializer */ public class IdentityTest { @@ -64,7 +60,7 @@ public void testCallFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(10, 10); - Identity instance = new 
Identity<>(tf, 2.); + Identity instance = new Identity(tf, 2.); Operand operand = instance.call(tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } @@ -90,7 +86,7 @@ public void testCallDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(10, 10); - Identity instance = new Identity<>(tf, 2.); + Identity instance = new Identity(tf, 2.); Operand operand = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } @@ -103,7 +99,7 @@ public void testReproducible() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Identity instance = new Identity<>(tf, 2.); + Identity instance = new Identity(tf, 2.); Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/LeCunTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/LeCunTest.java index 336850a5549..b6a5fe2a947 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/LeCunTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/LeCunTest.java @@ -51,7 +51,7 @@ public void testCallNormalFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - LeCun instance = new LeCun<>(tf, Distribution.TRUNCATED_NORMAL, SEED); + LeCun instance = new LeCun(tf, Distribution.TRUNCATED_NORMAL, SEED); Operand operand = instance.call(tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } @@ -66,7 +66,7 @@ public void testCallNormalDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - LeCun instance = new LeCun<>(tf, Distribution.TRUNCATED_NORMAL, SEED); + LeCun instance = new LeCun(tf, Distribution.TRUNCATED_NORMAL, SEED); Operand operand = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } @@ -80,7 +80,7 @@ public void testCallUniformFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - LeCun instance = new LeCun<>(tf, Distribution.UNIFORM, SEED); + LeCun instance = new LeCun(tf, Distribution.UNIFORM, SEED); Operand operand = instance.call(tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } @@ -95,7 +95,7 @@ public void testCallUniformDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - LeCun instance = new LeCun<>(tf, Distribution.UNIFORM, SEED); + LeCun instance = new LeCun(tf, Distribution.UNIFORM, SEED); Operand operand = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } @@ -107,7 +107,7 @@ public void testCallNormalReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - LeCun instance = new LeCun<>(tf, Distribution.TRUNCATED_NORMAL, SEED); + LeCun instance = new LeCun(tf, Distribution.TRUNCATED_NORMAL, SEED); Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); @@ -120,7 +120,7 @@ 
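// Editor's sketch (not part of the patch): the VarianceScaling-derived initializers
// (Glorot, He, LeCun) drop the class type parameter in the same way; the element type is
// chosen only at call time, as in the tests above.
Ops tf = Ops.create();
long seed = 1000L;
Glorot glorot = new Glorot(tf, VarianceScaling.Distribution.TRUNCATED_NORMAL, seed);
He he = new He(tf, VarianceScaling.Distribution.UNIFORM, seed);
LeCun leCun = new LeCun(tf, VarianceScaling.Distribution.NORMAL, seed);
Operand<TFloat64> g = glorot.call(tf.constant(Shape.of(2, 2)), TFloat64.class);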
public void testCallUniformReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - LeCun instance = new LeCun<>(tf, Distribution.UNIFORM, SEED); + LeCun instance = new LeCun(tf, Distribution.UNIFORM, SEED); Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); @@ -133,7 +133,7 @@ public void testCallNORMALReproducible() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - LeCun instance = new LeCun<>(tf, Distribution.NORMAL, SEED); + LeCun instance = new LeCun(tf, Distribution.NORMAL, SEED); Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OnesTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OnesTest.java index 053ba5dd7ff..d37d1f6eb7f 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OnesTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OnesTest.java @@ -51,7 +51,7 @@ public void testCallUInt() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Ones instance = new Ones<>(tf); + Ones instance = new Ones(tf); Operand operand = instance.call(tf.constant(shape), TUint8.class); session.evaluate(expected, operand); } @@ -65,7 +65,7 @@ public void testCallInt() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Ones instance = new Ones<>(tf); + Ones instance = new Ones(tf); Operand operand = instance.call(tf.constant(shape), TInt32.class); session.evaluate(expected, operand); } @@ -79,7 +79,7 @@ public void testCallLong() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Ones instance = new Ones<>(tf); + Ones instance = new Ones(tf); Operand operand = instance.call(tf.constant(shape), TInt64.class); session.evaluate(expected, operand); } @@ -93,7 +93,7 @@ public void testCallFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Ones instance = new Ones<>(tf); + Ones instance = new Ones(tf); Operand operand = instance.call(tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } @@ -108,7 +108,7 @@ public void testCallDouble() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Ones instance = new Ones<>(tf); + Ones instance = new Ones(tf); Operand operand = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } @@ -125,7 +125,7 @@ public void testCallString() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Ones instance = new Ones<>(tf); + Ones instance = new Ones(tf); instance.call(tf.constant(shape), TString.class); fail("IllegalArgumentException should have been thrown for TString"); } @@ -140,7 +140,7 @@ public void testCallBool() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Ones instance = new Ones<>(tf); + Ones instance = new Ones(tf); Operand operand = instance.call(tf.constant(shape), 
TBool.class); session.evaluate(expected, operand); } @@ -153,7 +153,7 @@ public void testReproducible() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Ones instance = new Ones<>(tf); + Ones instance = new Ones(tf); Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OrthogonalTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OrthogonalTest.java index 22b89d9177c..0badd39db30 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OrthogonalTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/OrthogonalTest.java @@ -21,10 +21,6 @@ import org.tensorflow.op.Ops; import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; -import org.tensorflow.types.TInt32; - -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.fail; /** Test the Orthogonal initializer */ public class OrthogonalTest { @@ -156,7 +152,7 @@ public void testCallFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(10, 10); - Orthogonal instance = new Orthogonal<>(tf, GAIN_VALUE, SEED); + Orthogonal instance = new Orthogonal(tf, GAIN_VALUE, SEED); Operand operand = instance.call(tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } @@ -271,7 +267,7 @@ public void testCallDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(10, 10); - Orthogonal instance = new Orthogonal<>(tf, GAIN_VALUE, SEED); + Orthogonal instance = new Orthogonal(tf, GAIN_VALUE, SEED); Operand operand = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } @@ -284,7 +280,7 @@ public void testReproducible() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - Orthogonal instance = new Orthogonal<>(tf, GAIN_VALUE, SEED); + Orthogonal instance = new Orthogonal(tf, GAIN_VALUE, SEED); Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/RandomNormalTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/RandomNormalTest.java index 3b2b3bdb243..5c0811fbc5a 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/RandomNormalTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/RandomNormalTest.java @@ -14,7 +14,7 @@ =======================================================================*/ package org.tensorflow.framework.initializers; -import org.junit.jupiter.api.*; +import org.junit.jupiter.api.Test; import org.tensorflow.Operand; import org.tensorflow.framework.utils.TestSession; import org.tensorflow.ndarray.Shape; @@ -22,6 +22,8 @@ import org.tensorflow.types.TFloat32; import org.tensorflow.types.TFloat64; +import static org.junit.jupiter.api.Assertions.assertThrows; + /** Test the RandomNormal initializer */ public class RandomNormalTest { @@ -32,18 +34,6 @@ public class RandomNormalTest { public RandomNormalTest() {} - @BeforeAll - public static void 
setUpClass() {} - - @AfterAll - public static void tearDownClass() {} - - @BeforeEach - public void setUp() {} - - @AfterEach - public void tearDown() {} - /** Test of call method, of class RandomNormal. */ @Test public void testCalltestSoftmaxFloat() { @@ -52,8 +42,7 @@ public void testCalltestSoftmaxFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - RandomNormal instance = - new RandomNormal<>(tf, MEAN_VALUE, STDDEV_VALUE, SEED); + RandomNormal instance = new RandomNormal(tf, MEAN_VALUE, STDDEV_VALUE, SEED); Operand operand = instance.call(tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } @@ -68,8 +57,7 @@ public void testCalltestSoftmaxDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - RandomNormal instance = - new RandomNormal<>(tf, MEAN_VALUE, STDDEV_VALUE, SEED); + RandomNormal instance = new RandomNormal(tf, MEAN_VALUE, STDDEV_VALUE, SEED); Operand operand = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } @@ -82,11 +70,24 @@ public void testReproducible() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - RandomNormal instance = - new RandomNormal<>(tf, MEAN_VALUE, STDDEV_VALUE, SEED); + RandomNormal instance = new RandomNormal(tf, MEAN_VALUE, STDDEV_VALUE, SEED); Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); } } + + @Test + public void testInvalidStdDev() { + for (TestSession.Mode tfMode : tfModes) + assertThrows( + IllegalArgumentException.class, + () -> { + try (TestSession session = TestSession.createTestSession(tfMode)) { + Ops tf = session.getTF(); + + RandomNormal instance = new RandomNormal(tf, MEAN_VALUE, -2.5, SEED); + } + }); + } } diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/RandomUniformTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/RandomUniformTest.java index 23e26083a9b..9ad6509d40c 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/RandomUniformTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/RandomUniformTest.java @@ -53,8 +53,7 @@ public void testCallInt() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - RandomUniform instance = - new RandomUniform<>(tf, MIN_VALUE, MAX_VALUE, SEED); + RandomUniform instance = new RandomUniform(tf, MIN_VALUE, MAX_VALUE, SEED); Operand operand = instance.call(tf.constant(shape), TInt32.class); session.evaluate(expected, operand); } @@ -68,8 +67,7 @@ public void testCallFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - RandomUniform instance = - new RandomUniform<>(tf, MIN_VALUE, MAX_VALUE, SEED); + RandomUniform instance = new RandomUniform(tf, MIN_VALUE, MAX_VALUE, SEED); Operand operand = instance.call(tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } @@ -84,8 +82,7 @@ public void testCallDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - RandomUniform instance = - new RandomUniform<>(tf, MIN_VALUE, MAX_VALUE, SEED); + RandomUniform 
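// Editor's sketch (not part of the patch): per the new testInvalidStdDev above, the
// RandomNormal constructor now validates its standard deviation argument. The mean,
// stddev and seed values below are illustrative; the test constants are not shown here.
Ops tf = Ops.create();
RandomNormal ok = new RandomNormal(tf, 0.0, 0.5, 1000L);   // mean, stddev, seed
// new RandomNormal(tf, 0.0, -2.5, 1000L);                 // throws IllegalArgumentException
Operand<TFloat32> sample = ok.call(tf.constant(Shape.of(2, 2)), TFloat32.class);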
instance = new RandomUniform(tf, MIN_VALUE, MAX_VALUE, SEED); Operand operand = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } @@ -98,8 +95,7 @@ public void testReproducible() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - RandomUniform instance = - new RandomUniform<>(tf, MIN_VALUE, MAX_VALUE, SEED); + RandomUniform instance = new RandomUniform(tf, MIN_VALUE, MAX_VALUE, SEED); Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/TruncatedNormalTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/TruncatedNormalTest.java index 96bf915e199..898cede46b2 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/TruncatedNormalTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/TruncatedNormalTest.java @@ -52,8 +52,7 @@ public void testCallFloat() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - TruncatedNormal instance = - new TruncatedNormal<>(tf, MEAN_VALUE, STDDEV_VALUE, SEED); + TruncatedNormal instance = new TruncatedNormal(tf, MEAN_VALUE, STDDEV_VALUE, SEED); Operand operand = instance.call(tf.constant(shape), TFloat32.class); session.evaluate(expected, operand); } @@ -68,8 +67,7 @@ public void testCallDouble() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - TruncatedNormal instance = - new TruncatedNormal<>(tf, MEAN_VALUE, STDDEV_VALUE, SEED); + TruncatedNormal instance = new TruncatedNormal(tf, MEAN_VALUE, STDDEV_VALUE, SEED); Operand operand = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(expected, operand); } @@ -82,8 +80,7 @@ public void testReproducible() { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - TruncatedNormal instance = - new TruncatedNormal<>(tf, MEAN_VALUE, STDDEV_VALUE, SEED); + TruncatedNormal instance = new TruncatedNormal(tf, MEAN_VALUE, STDDEV_VALUE, SEED); Operand operand1 = instance.call(tf.constant(shape), TFloat64.class); Operand operand2 = instance.call(tf.constant(shape), TFloat64.class); session.evaluate(operand1, operand2); diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/VarianceScalingTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/VarianceScalingTest.java index 159affb07e2..5a2819e64c0 100644 --- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/VarianceScalingTest.java +++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/VarianceScalingTest.java @@ -50,8 +50,8 @@ public void testCallFloat1FanInTruncatedNormal() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - VarianceScaling instance = - new VarianceScaling<>( + VarianceScaling instance = + new VarianceScaling( tf, 1.0, VarianceScaling.Mode.FAN_IN, @@ -71,8 +71,8 @@ public void testCallDouble1FanInTruncatedNormal() { try (TestSession session = TestSession.createTestSession(tfMode)) { Ops tf = session.getTF(); Shape shape = Shape.of(2, 2); - VarianceScaling instance = - new VarianceScaling<>( + VarianceScaling instance = + new VarianceScaling( tf, 1.0, 
                VarianceScaling.Mode.FAN_IN,
@@ -91,13 +91,9 @@ public void testCallFloat1FanInNormal() {
       try (TestSession session = TestSession.createTestSession(tfMode)) {
         Ops tf = session.getTF();
         Shape shape = Shape.of(2, 2);
-        VarianceScaling instance =
-            new VarianceScaling<>(
-                tf,
-                1.0,
-                VarianceScaling.Mode.FAN_IN,
-                VarianceScaling.Distribution.NORMAL,
-                SEED);
+        VarianceScaling instance =
+            new VarianceScaling(
+                tf, 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.NORMAL, SEED);
         Operand operand = instance.call(tf.constant(shape), TFloat32.class);
         session.evaluate(expected, operand);
       }
@@ -112,13 +108,9 @@ public void testCalltestSoftmaxDouble1FanInNormal() {
       try (TestSession session = TestSession.createTestSession(tfMode)) {
         Ops tf = session.getTF();
         Shape shape = Shape.of(2, 2);
-        VarianceScaling instance =
-            new VarianceScaling<>(
-                tf,
-                1.0,
-                VarianceScaling.Mode.FAN_IN,
-                VarianceScaling.Distribution.NORMAL,
-                SEED);
+        VarianceScaling instance =
+            new VarianceScaling(
+                tf, 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.NORMAL, SEED);
         Operand operand = instance.call(tf.constant(shape), TFloat64.class);
         session.evaluate(expected, operand);
       }
@@ -132,8 +124,8 @@ public void testCalltestSoftmaxFloat1FanInUNIFORM() {
       try (TestSession session = TestSession.createTestSession(tfMode)) {
         Ops tf = session.getTF();
         Shape shape = Shape.of(2, 2);
-        VarianceScaling instance =
-            new VarianceScaling<>(
+        VarianceScaling instance =
+            new VarianceScaling(
                 tf, 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.UNIFORM, SEED);
         Operand operand = instance.call(tf.constant(shape), TFloat32.class);
         session.evaluate(expected, operand);
@@ -149,8 +141,8 @@ public void testCalltestSoftmaxDouble1FanInUNIFORM() {
       try (TestSession session = TestSession.createTestSession(tfMode)) {
         Ops tf = session.getTF();
         Shape shape = Shape.of(2, 2);
-        VarianceScaling instance =
-            new VarianceScaling<>(
+        VarianceScaling instance =
+            new VarianceScaling(
                 tf, 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.UNIFORM, SEED);
         Operand operand = instance.call(tf.constant(shape), TFloat64.class);
         session.evaluate(expected, operand);
@@ -164,8 +156,8 @@ public void testReproducible1() {
         Ops tf = session.getTF();
         Shape shape = Shape.of(2, 2);

-        VarianceScaling instance =
-            new VarianceScaling<>(
+        VarianceScaling instance =
+            new VarianceScaling(
                 tf, 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.UNIFORM, SEED);
         Operand operand1 = instance.call(tf.constant(shape), TFloat64.class);
         Operand operand2 = instance.call(tf.constant(shape), TFloat64.class);
@@ -180,13 +172,9 @@ public void testReproducible2() {
         Ops tf = session.getTF();
         Shape shape = Shape.of(2, 2);

-        VarianceScaling instance =
-            new VarianceScaling<>(
-                tf,
-                1.0,
-                VarianceScaling.Mode.FAN_IN,
-                VarianceScaling.Distribution.NORMAL,
-                SEED);
+        VarianceScaling instance =
+            new VarianceScaling(
+                tf, 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.NORMAL, SEED);
         Operand operand1 = instance.call(tf.constant(shape), TFloat64.class);
         Operand operand2 = instance.call(tf.constant(shape), TFloat64.class);
         session.evaluate(operand1, operand2);
@@ -200,8 +188,8 @@ public void testReproducible3() {
         Ops tf = session.getTF();
         Shape shape = Shape.of(2, 2);

-        VarianceScaling instance =
-            new VarianceScaling<>(
+        VarianceScaling instance =
+            new VarianceScaling(
                 tf,
                 1.0,
                 VarianceScaling.Mode.FAN_OUT,
@@ -220,8 +208,8 @@ public void testReproducible4() {
         Ops tf = session.getTF();
         Shape shape = Shape.of(2, 2);

-        VarianceScaling instance =
-            new VarianceScaling<>(
+        VarianceScaling instance =
+            new VarianceScaling(
                 tf, 1.0, VarianceScaling.Mode.FAN_AVG, VarianceScaling.Distribution.UNIFORM, SEED);
         Operand operand1 = instance.call(tf.constant(shape), TFloat64.class);
         Operand operand2 = instance.call(tf.constant(shape), TFloat64.class);
diff --git a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ZerosTest.java b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ZerosTest.java
index 21bad6ff360..f81df29dae1 100644
--- a/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ZerosTest.java
+++ b/tensorflow-framework/src/test/java/org/tensorflow/framework/initializers/ZerosTest.java
@@ -21,6 +21,8 @@
 import org.tensorflow.op.Ops;
 import org.tensorflow.types.*;

+import static org.junit.jupiter.api.Assertions.assertThrows;
+
 /** Test the Zeros initializer */
 public class ZerosTest {

@@ -48,7 +50,7 @@ public void testCallUInt() {
       try (TestSession session = TestSession.createTestSession(tfMode)) {
         Ops tf = session.getTF();
         Shape shape = Shape.of(2, 2);
-        Zeros instance = new Zeros<>(tf);
+        Zeros instance = new Zeros(tf);
         Operand operand = instance.call(tf.constant(shape), TUint8.class);
         session.evaluate(expected, operand);
       }
@@ -62,7 +64,7 @@ public void testCallInt() {
       try (TestSession session = TestSession.createTestSession(tfMode)) {
         Ops tf = session.getTF();
         Shape shape = Shape.of(2, 2);
-        Zeros instance = new Zeros<>(tf);
+        Zeros instance = new Zeros(tf);
         Operand operand = instance.call(tf.constant(shape), TInt32.class);
         session.evaluate(expected, operand);
       }
@@ -76,7 +78,7 @@ public void testCallLong() {
       try (TestSession session = TestSession.createTestSession(tfMode)) {
         Ops tf = session.getTF();
         Shape shape = Shape.of(2, 2);
-        Zeros instance = new Zeros<>(tf);
+        Zeros instance = new Zeros(tf);
         Operand operand = instance.call(tf.constant(shape), TInt64.class);
         session.evaluate(expected, operand);
       }
@@ -90,7 +92,7 @@ public void testCallFloat() {
       try (TestSession session = TestSession.createTestSession(tfMode)) {
         Ops tf = session.getTF();
         Shape shape = Shape.of(2, 2);
-        Zeros instance = new Zeros<>(tf);
+        Zeros instance = new Zeros(tf);
         Operand operand = instance.call(tf.constant(shape), TFloat32.class);
         session.evaluate(expected, operand);
       }
@@ -105,7 +107,7 @@ public void testCallDouble() {
         Ops tf = session.getTF();
         Shape shape = Shape.of(2, 2);

-        Zeros instance = new Zeros<>(tf);
+        Zeros instance = new Zeros(tf);
         Operand operand = instance.call(tf.constant(shape), TFloat64.class);
         session.evaluate(expected, operand);
       }
@@ -115,14 +117,18 @@ public void testCallDouble() {
   @Test
   public void testCallString() {
     for (TestSession.Mode tfMode : tfModes)
-      try (TestSession session = TestSession.createTestSession(tfMode)) {
-        Ops tf = session.getTF();
-        Shape shape = Shape.of(2, 2);
-
-        Zeros instance = new Zeros<>(tf);
-        Operand operand = instance.call(tf.constant(shape), TString.class);
-        session.evaluateString(operand, String::isEmpty);
-      }
+      assertThrows(
+          java.lang.IllegalArgumentException.class,
+          () -> {
+            try (TestSession session = TestSession.createTestSession(tfMode)) {
+              Ops tf = session.getTF();
+              Shape shape = Shape.of(2, 2);
+
+              Zeros instance = new Zeros(tf);
+              Operand operand = instance.call(tf.constant(shape), TString.class);
+              session.evaluateString(operand, String::isEmpty);
+            }
+          });
   }
   /** Test of call method, of class Zeros.
    */
@@ -134,7 +140,7 @@ public void testCallBool() {
         Ops tf = session.getTF();
         Shape shape = Shape.of(2, 2);

-        Zeros instance = new Zeros<>(tf);
+        Zeros instance = new Zeros(tf);
         Operand operand = instance.call(tf.constant(shape), TBool.class);
         session.evaluate(expected, operand);
       }
@@ -147,7 +153,7 @@ public void testReproducible() {
         Ops tf = session.getTF();
         Shape shape = Shape.of(2, 2);

-        Zeros instance = new Zeros<>(tf);
+        Zeros instance = new Zeros(tf);
         Operand operand1 = instance.call(tf.constant(shape), TFloat64.class);
         Operand operand2 = instance.call(tf.constant(shape), TFloat64.class);
         session.evaluate(operand1, operand2);
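
For orientation, a minimal sketch of how the initializers are exercised after this change, mirroring the updated tests: the class-level type parameter is gone, so instances are constructed raw and the element type is passed to `call`. The eager-session setup, the class name `InitializerUsageSketch`, and the seed value below are assumptions for illustration only; they are not part of this patch, and the real tests use the framework's `TestSession` instead.

import org.tensorflow.EagerSession;
import org.tensorflow.Operand;
import org.tensorflow.framework.initializers.VarianceScaling;
import org.tensorflow.framework.initializers.Zeros;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.op.Ops;
import org.tensorflow.types.TFloat32;
import org.tensorflow.types.TFloat64;

/** Sketch only; mirrors the usage shown in the updated initializer tests. */
public class InitializerUsageSketch {
  public static void main(String[] args) {
    try (EagerSession env = EagerSession.create()) {
      Ops tf = Ops.create(env); // assumed setup; the tests obtain Ops from TestSession
      Shape shape = Shape.of(2, 2);
      long seed = 1001L; // arbitrary value; the tests share a SEED constant

      // VarianceScaling is no longer generic; the element type is chosen per call,
      // so one instance can produce operands of different floating-point types.
      VarianceScaling varianceScaling =
          new VarianceScaling(
              tf, 1.0, VarianceScaling.Mode.FAN_IN, VarianceScaling.Distribution.NORMAL, seed);
      Operand floats = varianceScaling.call(tf.constant(shape), TFloat32.class);
      Operand doubles = varianceScaling.call(tf.constant(shape), TFloat64.class);

      // Zeros follows the same pattern; per the updated testCallString, calling it
      // with TString.class is now expected to throw IllegalArgumentException.
      Zeros zeros = new Zeros(tf);
      Operand zeroed = zeros.call(tf.constant(shape), TFloat32.class);
    }
  }
}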