diff --git a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java
index ea3ef31313e..acbae4dac6b 100644
--- a/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java
+++ b/tensorflow-core/tensorflow-core-api/src/gen/annotations/org/tensorflow/op/Ops.java
@@ -354,20 +354,20 @@ public final class Ops {
public final SparseOps sparse;
- public final TpuOps tpu;
-
public final BitwiseOps bitwise;
+ public final TpuOps tpu;
+
public final MathOps math;
public final AudioOps audio;
public final SignalOps signal;
- public final TrainOps train;
-
public final QuantizationOps quantization;
+ public final TrainOps train;
+
private final Scope scope;
private Ops(Scope scope) {
@@ -385,13 +385,13 @@ private Ops(Scope scope) {
random = new RandomOps(this);
strings = new StringsOps(this);
sparse = new SparseOps(this);
- tpu = new TpuOps(this);
bitwise = new BitwiseOps(this);
+ tpu = new TpuOps(this);
math = new MathOps(this);
audio = new AudioOps(this);
signal = new SignalOps(this);
- train = new TrainOps(this);
quantization = new QuantizationOps(this);
+ train = new TrainOps(this);
}
/**
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Activation.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Activation.java
index e1482a51a8a..708c7daead6 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Activation.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Activation.java
@@ -18,14 +18,7 @@
import org.tensorflow.op.Ops;
import org.tensorflow.types.family.TNumber;
-/**
- * Abstract base class for Activations
- *
- * <p>Note: The {@link #tf} attribute must be set prior to invoking the call method. See
- * {@link #setTF(Ops)} and the constructor {@link #Activation(Ops)}.
- *
- * @param <T> the data type of the activation
- */
+/** Abstract base class for Activations */
-public abstract class Activation<T extends TNumber> {
+public abstract class Activation {
/** The TensorFlow Ops */
@@ -41,28 +34,29 @@ protected Activation(Ops tf) {
}
/**
- * Sets the TensorFlow Ops
+ * Gets the TensorFlow Ops
*
- * @param tf the TensorFlow Ops
+ * @return the TensorFlow Ops
*/
- protected void setTF(Ops tf) {
- this.tf = tf;
+ protected Ops getTF() {
+ return this.tf;
}
/**
- * Gets the TensorFlow Ops
+ * Sets the TensorFlow Ops
*
- * @return the TensorFlow Ops
+ * @param tf the TensorFlow Ops
*/
- protected Ops getTF() {
- return this.tf;
+ protected void setTF(Ops tf) {
+ this.tf = tf;
}
/**
* Gets the calculation operation for the activation.
*
* @param input the input tensor
+ * @param <T> the data type of the input and result
* @return The operand for the activation
*/
- public abstract Operand<T> call(Operand<T> input);
+ public abstract <T extends TNumber> Operand<T> call(Operand<T> input);
}
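For reviewers, a quick sketch of what the relocated type parameter means at a call site. This is a minimal example, assuming an eager session via Ops.create() and the Tanh subclass touched later in this PR; any other activation works the same way:

```java
Ops tf = Ops.create(); // eager session
Tanh tanh = new Tanh(tf); // activation classes are no longer generic

// <T extends TNumber> now lives on call(), so a single instance can be
// applied to operands of different numeric types.
Operand<TFloat32> r32 = tanh.call(tf.constant(new float[] {-1f, 0f, 1f}));
Operand<TFloat64> r64 = tanh.call(tf.constant(new double[] {-1, 0, 1}));
```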
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ELU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ELU.java
index 2f2f16f2752..00c720c936b 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ELU.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ELU.java
@@ -44,11 +44,10 @@
* Operand<TFloat32> result = elu.call(input);
*
*
- * @param <T> the data type of the activation
* @see Clevert et al, 2016, Fast and Accurate Deep
* Network Learning by Exponential Linear Units (ELUs)
*/
-public class ELU<T extends TFloating> extends Activation<T> {
+public class ELU extends Activation {
private static final double ALPHA_DEFAULT = 1.0;
@@ -76,20 +75,16 @@ public ELU(Ops tf, double alpha) {
this.alpha = alpha;
}
- /**
- * Gets the calculation operation for the activation.
- *
- * @param input the input tensor
- * @return The operand for the activation
- */
+ /** {@inheritDoc} */
@Override
- public Operand<T> call(Operand<T> input) {
+ public <T extends TNumber> Operand<T> call(Operand<T> input) {
- Operand<T> result = tf.nn.elu(input);
- if (alpha == 1.0) return result;
- else {
- Class<T> inputType = input.type();
- Operand<T> y = tf.math.mul(result, tf.dtypes.cast(tf.constant(alpha), inputType));
+ Operand<T> result = tf.nn.elu(input);
+ if (alpha == 1.0) {
+ return result;
+ } else {
+ Class<T> inputType = input.type();
+ Operand<T> y = tf.math.mul(result, tf.dtypes.cast(tf.constant(alpha), inputType));
Operand<TBool> cond = tf.math.greater(result, tf.dtypes.cast(tf.constant(0), inputType));
return tf.select(cond, result, y);
}
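A usage sketch for the non-default alpha path exercised by the rewritten body above; the expected values follow from alpha * (e^x - 1) for x <= 0 and the identity for x > 0 (assumes an eager Ops named tf):

```java
ELU elu = new ELU(tf, 2.0); // alpha != 1.0 takes the select(cond, result, y) branch
Operand<TFloat32> result = elu.call(tf.constant(new float[] {-1f, 0f, 2f}));
// result is approximately [2 * (e^-1 - 1), 0, 2] = [-1.2642f, 0f, 2f]
```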
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Exponential.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Exponential.java
index d5fdff36c61..512f96713aa 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Exponential.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Exponential.java
@@ -30,10 +30,8 @@
* Operand<TFloat32> result = exp.call(input);
* // result is [0.04978707f, 0.36787945f, 1.f, 2.7182817f, 20.085537f]
*
- *
- * @param <T> the data type of the activation
*/
-public class Exponential<T extends TFloating> extends Activation<T> {
+public class Exponential extends Activation {
/**
* Creates an Exponential activation.
@@ -48,10 +46,12 @@ public Exponential(Ops tf) {
* Calculates the Exponential activation.
*
* @param input the input tensor
+ * @param <T> the data type of the input and result
+ * @return an Operand for the exponential activation: exp(x).
*/
@Override
- public Operand<T> call(Operand<T> input) {
+ public <T extends TNumber> Operand<T> call(Operand<T> input) {
+
return tf.math.exp(input);
}
}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/GeLU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/GeLU.java
new file mode 100644
index 00000000000..abddbd512cf
--- /dev/null
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/GeLU.java
@@ -0,0 +1,116 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+=======================================================================*/
+package org.tensorflow.framework.activations;
+
+import org.tensorflow.Operand;
+import org.tensorflow.op.Ops;
+import org.tensorflow.types.family.TFloating;
+
+import static org.tensorflow.framework.utils.CastHelper.cast;
+
+/**
+ * Applies the Gaussian error linear unit (GELU) activation function.
+ *
+ * Gaussian error linear unit (GELU) computes {@code x * P(X <= x)}, where {@code P(X) ~ N(0,
+ * 1)}. The (GELU) nonlinearity weights inputs by their value, rather than gates inputs by their
+ * sign as in ReLU. If {@code approximate} is {@code true}:
+ *
+ * <pre>
+ * 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
+ * </pre>
+ *
+ * or, if {@code approximate} is {@code false}:
+ *
+ * <pre>
+ * x * P(X <= x) = 0.5 * x * (1 + erf(x / sqrt(2))),
+ * </pre>
+ *
+ * where {@code P(X) ~ N(0, 1)}.
+ *
+ * @see Hendrycks, Dan and Gimpel, Kevin, 2016-2020,
+ * Gaussian Error Linear Units (GELUs)
+ */
+public class GeLU extends Activation {
+
+ private final boolean approximate;
+
+ /**
+ * Creates a Gaussian error linear unit (GELU) activation, with {@code approximate} set to false
+ *
+ * @param tf The TensorFlow ops
+ */
+ public GeLU(Ops tf) {
+ this(tf, false);
+ }
+
+ /**
+ * Creates a Gaussian error linear unit (GELU) activation
+ *
+ * @param tf The TensorFlow ops
+ * @param approximate whether to use the tanh approximation rather than the exact formulation.
+ */
+ public GeLU(Ops tf, boolean approximate) {
+ super(tf);
+ this.approximate = approximate;
+ }
+
+ /** {@inheritDoc} */
+ @Override
+ public <T extends TNumber> Operand<T> call(Operand<T> input) {
+
+ Operand<T> point5 = cast(tf, tf.constant(0.5), input.type());
+ Operand<T> one = cast(tf, tf.constant(1.0), input.type());
+
+ if (approximate) {
+ /*
+ coeff = math_ops.cast(0.044715, features.dtype)
+ return 0.5 * features * (
+ 1.0 + math_ops.tanh(0.7978845608028654 *
+ (features + coeff * math_ops.pow(features, 3))))
+ */
+ Operand<T> coeff = cast(tf, tf.constant(0.044715), input.type());
+ // sqrt(2.0 / PI)
+ Operand<T> sqrtTwoDivPI = cast(tf, tf.constant(0.7978845608028654), input.type());
+ Operand<T> three = cast(tf, tf.constant(3), input.type());
+
+ return tf.math.mul(
+ point5,
+ tf.math.mul(
+ input,
+ tf.math.add(
+ one,
+ tf.math.tanh(
+ tf.math.mul(
+ sqrtTwoDivPI,
+ tf.math.add(
+ input, tf.math.mul(coeff, tf.math.pow(input, three)) // mul
+ ) // add
+ ) // mul
+ ) // tanh
+ ) // add
+ ) // mul
+ ); // mul
+
+ } else {
+ /*
+ return 0.5 * features * (1.0 + math_ops.erf(
+ features / math_ops.cast(1.4142135623730951, features.dtype)))
+ */
+ Operand<T> sqrtTwo = cast(tf, tf.constant(1.4142135623730951), input.type());
+ return tf.math.mul(
+ point5, tf.math.mul(input, tf.math.add(one, tf.math.erf(tf.math.div(input, sqrtTwo)))));
+ }
+ }
+}
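A usage sketch for the new class, comparing the exact and approximate forms; expected values are computed from the two formulas in the javadoc (assumes an eager Ops named tf):

```java
GeLU exact = new GeLU(tf); // 0.5 * x * (1 + erf(x / sqrt(2)))
GeLU approx = new GeLU(tf, true); // tanh approximation
Operand<TFloat32> x = tf.constant(new float[] {-1f, 0f, 1f});
Operand<TFloat32> y1 = exact.call(x); // approximately [-0.1587f, 0f, 0.8413f]
Operand<TFloat32> y2 = approx.call(x); // approximately [-0.1588f, 0f, 0.8412f]
```

The two branches agree to roughly three decimal places for moderate inputs, which is the motivation for offering the cheaper tanh form as an option.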
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/HardSigmoid.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/HardSigmoid.java
index 0b7cf573b8e..6c7323a90cb 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/HardSigmoid.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/HardSigmoid.java
@@ -16,7 +16,7 @@
import org.tensorflow.Operand;
import org.tensorflow.op.Ops;
-import org.tensorflow.types.family.TFloating;
+import org.tensorflow.types.family.TNumber;
/**
* Hard sigmoid activation.
@@ -40,10 +40,8 @@
* Operand<TFloat32> result = hardSigmoid.call(input);
* // result is [0.f , 0.3f, 0.5f, 0.7f, 1.f]
*
- *
- * @param <T> the data type of the result
*/
-public class HardSigmoid<T extends TFloating> extends Activation<T> {
+public class HardSigmoid extends Activation {
/**
* Creates Hard sigmoid activation.
@@ -54,19 +52,14 @@ public HardSigmoid(Ops tf) {
super(tf);
}
- /**
- * Gets the calculation operation for the activation.
- *
- * @param input the input tensor
- * @return The operand for the activation
- */
+ /** {@inheritDoc} */
@Override
- public Operand<T> call(Operand<T> input) {
- Class<T> inputType = input.type();
- Operand<T> point2 = tf.dtypes.cast(tf.constant(0.2), inputType);
- Operand<T> point5 = tf.dtypes.cast(tf.constant(0.5), inputType);
+ public <T extends TNumber> Operand<T> call(Operand<T> input) {
+ Class<T> inputType = input.type();
+ Operand<T> point2 = tf.dtypes.cast(tf.constant(0.2), inputType);
+ Operand<T> point5 = tf.dtypes.cast(tf.constant(0.5), inputType);
- Operand<T> x = tf.math.add(tf.math.mul(input, point2), point5);
+ Operand<T> x = tf.math.add(tf.math.mul(input, point2), point5);
return tf.clipByValue(
x, tf.dtypes.cast(tf.constant(0), inputType), tf.dtypes.cast(tf.constant(1), inputType));
}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Linear.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Linear.java
index d907397995d..dcda76db6bf 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Linear.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Linear.java
@@ -19,9 +19,9 @@
import org.tensorflow.types.family.TNumber;
/**
- * Linear activation function (pass-through).
+ * Linear activation function (pass-through).
*
- * The linear activation returns its input. It is also known as the Identity activation function.
+ * The linear activation returns its input. It is also known as the Identity activation function.
*
* <p>For example:
*
@@ -33,7 +33,7 @@
* // result is [-3.0f,-1.0f, 0.0f,1.0f,3.0f]
*
*/
-public class Linear<T extends TNumber> extends Activation<T> {
+public class Linear extends Activation {
/**
* Creates a linear activation.
@@ -46,7 +46,7 @@ public Linear(Ops tf) {
/** {@inheritDoc} */
@Override
- public Operand<T> call(Operand<T> input) {
+ public <T extends TNumber> Operand<T> call(Operand<T> input) {
return input;
}
}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ReLU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ReLU.java
index aef6ebf2992..409015a1203 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ReLU.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/ReLU.java
@@ -55,10 +55,8 @@
* result = relu.call(input);
* // result is [-0.f, -0.f, 0.f, 0.f, 10.f]
*
- *
- * @param <T> the data type of the result
*/
-public class ReLU<T extends TNumber> extends Activation<T> {
+public class ReLU extends Activation {
public static final float ALPHA_DEFAULT = 0.0f;
public static final float MAX_VALUE_DEFAULT = Float.NaN;
@@ -96,11 +94,11 @@ public ReLU(Ops tf, float alpha, float maxValue, float threshold) {
/** {@inheritDoc} */
@Override
- public Operand<T> call(Operand<T> input) {
- Class<T> inputType = input.type();
+ public <T extends TNumber> Operand<T> call(Operand<T> input) {
+ Class<T> inputType = input.type();
boolean clipMax = !Float.isNaN(maxValue);
- Operand<T> negativePart = null;
+ Operand<T> negativePart = null;
if (alpha != 0) {
if (Float.isNaN(maxValue) && threshold == 0) {
return tf.nn.leakyRelu(input, LeakyRelu.alpha(alpha));
@@ -114,7 +112,7 @@ public Operand call(Operand input) {
}
}
- Operand<T> lInput;
+ Operand<T> lInput;
if (threshold != 0) {
// computes input for input > threshold else 0
Greater greater = tf.math.greater(input, tf.dtypes.cast(tf.constant(threshold), inputType));
@@ -127,8 +125,8 @@ public Operand call(Operand input) {
lInput = tf.nn.relu(input);
}
if (clipMax) {
- Operand<T> lmaxValue = tf.dtypes.cast(tf.constant(maxValue), inputType);
- Operand<T> zero = tf.dtypes.cast(tf.constant(0), inputType);
+ Operand<T> lmaxValue = tf.dtypes.cast(tf.constant(maxValue), inputType);
+ Operand<T> zero = tf.dtypes.cast(tf.constant(0), inputType);
lInput = tf.clipByValue(lInput, zero, lmaxValue);
}
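A sketch exercising all three knobs of the rewritten call; the expected output follows the leaky/threshold/clip logic above (assumes an eager Ops named tf):

```java
ReLU relu = new ReLU(tf, 0.1f, 6.0f, 1.0f); // alpha, maxValue, threshold
Operand<TFloat32> y = relu.call(tf.constant(new float[] {-2f, 0.5f, 2f, 10f}));
// x <= threshold: alpha * (x - threshold); threshold < x < maxValue: x; x >= maxValue: clipped
// y is approximately [-0.3f, -0.05f, 2f, 6f]
```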
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/SELU.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/SELU.java
index f24731049fb..7955144abbb 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/SELU.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/SELU.java
@@ -42,10 +42,9 @@
* Notes: To be used together with the {@link
* org.tensorflow.framework.initializers.LeCun} initializer with Normal Distribution.
*
- * @param <T> the data type of the activation
* @see Klambauer et al., 2017
*/
-public class SELU<T extends TFloating> extends Activation<T> {
+public class SELU extends Activation {
/**
* Creates a Scaled Exponential Linear Unit (SELU) activation.
@@ -56,14 +55,9 @@ public SELU(Ops tf) {
super(tf);
}
- /**
- * Gets the calculation operation for the activation.
- *
- * @param input the input tensor
- * @return The operand for the activation
- */
+ /** {@inheritDoc} */
@Override
- public Operand<T> call(Operand<T> input) {
+ public <T extends TNumber> Operand<T> call(Operand<T> input) {
return tf.nn.selu(input);
}
}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Sigmoid.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Sigmoid.java
index 5d507b38483..a89b6119d02 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Sigmoid.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Sigmoid.java
@@ -38,10 +38,8 @@
* // result is [2.0611537e-09f, 2.6894143e-01f,
* // 5.0000000e-01f,7.3105860e-01f, 1.f]
*
- *
- * @param <T> the data type of the activation
*/
-public class Sigmoid<T extends TFloating> extends Activation<T> {
+public class Sigmoid extends Activation {
/**
* Creates a Sigmoid activation.
@@ -52,14 +50,9 @@ public Sigmoid(Ops tf) {
super(tf);
}
- /**
- * Gets the calculation operation for the activation.
- *
- * @param input the input tensor
- * @return The operand for the activation
- */
+ /** {@inheritDoc} */
@Override
- public Operand<T> call(Operand<T> input) {
+ public <T extends TNumber> Operand<T> call(Operand<T> input) {
return tf.math.sigmoid(input);
}
}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softmax.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softmax.java
index 154e1ecc84a..309cdd68b8b 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softmax.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softmax.java
@@ -35,10 +35,8 @@
* The softmax of each vector x is computed as: exp(x) / tf.sum(exp(x)).
*
* <p>The input values are the log-odds of the resulting probability.
- *
- * @param <T> the data type of the activation
*/
-public class Softmax<T extends TFloating> extends Activation<T> {
+public class Softmax extends Activation {
private static final int AXIS_DEFAULT = -1;
@@ -65,23 +63,18 @@ public Softmax(Ops tf, int axis) {
this.axis = axis;
}
- /**
- * Gets the calculation operation for the activation.
- *
- * @param input the input tensor
- * @return The operand for the activation
- */
+ /** {@inheritDoc} */
@Override
- public Operand<T> call(Operand<T> input) {
+ public <T extends TNumber> Operand<T> call(Operand<T> input) {
Shape shape = input.shape();
int numDimensions = shape.numDimensions();
if (numDimensions == 2) {
return tf.nn.softmax(input);
} else {
- Operand<T> e =
+ Operand<T> e =
tf.math.exp(
tf.math.sub(input, tf.reduceMax(input, tf.constant(axis), ReduceMax.keepDims(true))));
- Operand<T> s = tf.reduceSum(e, tf.constant(axis), ReduceSum.keepDims(true));
+ Operand<T> s = tf.reduceSum(e, tf.constant(axis), ReduceSum.keepDims(true));
return tf.math.div(e, s);
}
}
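A sketch of the non-2D branch, which subtracts the max, exponentiates, and normalizes along the default axis of -1 (assumes an eager Ops named tf):

```java
Softmax softmax = new Softmax(tf); // axis defaults to -1
Operand<TFloat32> logits = tf.constant(new float[][][] {{{1f, 1f}, {1f, 3f}}}); // rank 3
Operand<TFloat32> probs = softmax.call(logits);
// each innermost pair sums to 1: [[[0.5f, 0.5f], [0.1192f, 0.8808f]]]
```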
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softplus.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softplus.java
index 65a183ea047..0eb703aad9f 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softplus.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softplus.java
@@ -32,7 +32,7 @@
* // 1.3132616e+00f, 2.0000000e+01f]
*
*/
-public class Softplus<T extends TFloating> extends Activation<T> {
+public class Softplus extends Activation {
/**
* Creates a Softplus activation function.
@@ -43,14 +43,9 @@ public Softplus(Ops tf) {
super(tf);
}
- /**
- * Gets the calculation operation for the activation.
- *
- * @param input the input tensor
- * @return The operand for the activation
- */
+ /** {@inheritDoc} */
@Override
- public Operand<T> call(Operand<T> input) {
+ public <T extends TNumber> Operand<T> call(Operand<T> input) {
return tf.math.softplus(input);
}
}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softsign.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softsign.java
index 1f691e71862..0a7754258df 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softsign.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Softsign.java
@@ -30,10 +30,8 @@
* Operand<TFloat32> result = softsign.call(input);
* // result is [-0.5f, 0.f, 0.5f]
*
- *
- * @param <T> the data type of the activation
*/
-public class Softsign<T extends TFloating> extends Activation<T> {
+public class Softsign extends Activation {
/**
* Creates a Softsign activation.
@@ -44,14 +42,9 @@ public Softsign(Ops tf) {
super(tf);
}
- /**
- * Gets the calculation operation for the activation.
- *
- * @param input the input tensor
- * @return The operand for the activation
- */
+ /** {@inheritDoc} */
@Override
- public Operand<T> call(Operand<T> input) {
+ public <T extends TNumber> Operand<T> call(Operand<T> input) {
return tf.nn.softsign(input);
}
}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Swish.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Swish.java
index d9f73a422d5..eb43a21f285 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Swish.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Swish.java
@@ -37,10 +37,9 @@
*
*
*
- * @param <T> the data type of the activation
* @see Ramachandran et al., 2017
*/
-public class Swish<T extends TFloating> extends Activation<T> {
+public class Swish extends Activation {
/**
* Creates a Swish activation, swish(x) = x * sigmoid(x).
@@ -57,7 +56,7 @@ public Swish(Ops tf) {
/** {@inheritDoc} */
@Override
- public Operand<T> call(Operand<T> input) {
+ public <T extends TNumber> Operand<T> call(Operand<T> input) {
// TODO Python Keras returns a "grad", which is an optimization not implemented in Java.
return tf.math.mul(input, tf.math.sigmoid(input));
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Tanh.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Tanh.java
index 4fe02eed048..145561b9129 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Tanh.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/activations/Tanh.java
@@ -30,10 +30,8 @@
* Operand<TFloat32> result = tanh.call(input);
* // result = [-0.9950547f, -0.7615942f, 0.f, 0.7615942f, 0.9950547f]
*
- *
- * @param <T> the data type of the activation
*/
-public class Tanh<T extends TFloating> extends Activation<T> {
+public class Tanh extends Activation {
/**
* Creates a Hyperbolic tangent activation.
@@ -46,7 +44,7 @@ public Tanh(Ops tf) {
/** {@inheritDoc} */
@Override
- public Operand<T> call(Operand<T> input) {
+ public <T extends TNumber> Operand<T> call(Operand<T> input) {
return tf.math.tanh(input);
}
}
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java
index d3094b5e9e9..739f9f55c55 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/constraints/Constraint.java
@@ -41,7 +41,9 @@ public Constraint(Ops tf) {
* Applies the constraint against the provided weights
*
* @param weights the weights
+ * @param <T> the data type of the weights and result
* @return the constrained weights
+ *
*/
public abstract <T extends TNumber> Operand<T> call(Operand<T> weights);
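For context, concrete constraints declare the type parameter on call as well. A sketch of a hypothetical non-negativity constraint (illustration only, not part of this PR):

```java
public class NonNeg extends Constraint {

  public NonNeg(Ops tf) {
    super(tf);
  }

  /** Zeroes out any negative weights. */
  @Override
  public <T extends TNumber> Operand<T> call(Operand<T> weights) {
    Operand<T> zero = tf.dtypes.cast(tf.constant(0), weights.type());
    return tf.math.maximum(weights, zero);
  }
}
```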
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Constant.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Constant.java
index 4a2df86d74b..3f2ebe58cb4 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Constant.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Constant.java
@@ -16,11 +16,11 @@
import org.tensorflow.Operand;
import org.tensorflow.op.Ops;
-import org.tensorflow.types.TBool;
import org.tensorflow.types.TInt64;
-import org.tensorflow.types.family.TNumber;
import org.tensorflow.types.family.TType;
+import static org.tensorflow.framework.utils.CastHelper.cast;
+
/**
* Initializer that generates tensors with a constant value.
*
@@ -33,76 +33,39 @@
* initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class);
*
*
- * @param <T> The Type for the call operation
+ * <p>Only scalar values are allowed. The constant value provided must be convertible to the data
+ * type requested when calling the initializer.
*/
-public class Constant<T extends TType> extends BaseInitializer<T> {
-
- private final double doubleValue;
- private final long longValue;
- private final boolean booleanValue;
- private final ValueType valueType;
+public class Constant extends BaseInitializer {
- /**
- * Creates an Initializer that generates tensors with a constant value.
- *
- * @param tf the TensorFlow Ops
- * @param value a long value used for the constant.
- */
- public Constant(Ops tf, long value) {
- super(tf);
- longValue = value;
- doubleValue = 0;
- booleanValue = false;
- valueType = ValueType.LONG;
- }
+ private final Operand<? extends TType> value;
/**
* Creates an Initializer that generates tensors with a constant value.
*
* @param tf the TensorFlow Ops
- * @param value a double value used for the constant.
+ * @param value the value used for the constant.
+ * @throws IllegalArgumentException if value is not a scalar.
*/
- public Constant(Ops tf, double value) {
+ public Constant(Ops tf, Operand<? extends TType> value) {
super(tf);
- doubleValue = value;
- longValue = 0;
- booleanValue = false;
- valueType = ValueType.DOUBLE;
+ if (!value.shape().isScalar()) {
+ throw new IllegalArgumentException("value must be scalar, got shape : " + value.shape());
+ }
+ this.value = value;
}
/**
- * Creates an Initializer that generates tensors with a constant value.
+ * Generates the operation used to perform the initialization.
*
- * @param tf the TensorFlow Ops
- * @param value a boolean value used for the constant.
+ * @param dims the shape dimensions
+ * @param type the data type of tensor
+ * @param <U> The data Type for initializer operation
+ * @return An operand for the initialization.
*/
- public Constant(Ops tf, boolean value) {
- super(tf);
- booleanValue = value;
- doubleValue = 0;
- longValue = 0;
- valueType = ValueType.BOOLEAN;
- }
-
- /** {@inheritDoc} */
@Override
- public Operand<T> call(Operand<TInt64> dims, Class<T> type) {
- if (!TNumber.class.isAssignableFrom(type) && type != TBool.class) {
- throw new IllegalArgumentException("Tensor type must be numeric or boolean: " + type.getSimpleName());
- }
- switch (valueType) {
- case LONG:
- return tf.fill(dims, tf.dtypes.cast(tf.constant(longValue), type));
- case DOUBLE:
- return tf.fill(dims, tf.dtypes.cast(tf.constant(doubleValue), type));
- default:
- return tf.fill(dims, tf.dtypes.cast(tf.constant(booleanValue), type));
- }
- }
+ public <U extends TType> Operand<U> call(Operand<TInt64> dims, Class<U> type) {
- private enum ValueType {
- LONG,
- DOUBLE,
- BOOLEAN
+ return tf.fill(dims, cast(tf, value, type));
}
}
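Call sites that previously picked among the long/double/boolean constructors now pass a scalar Operand, and the value is cast to the requested type on each call. A sketch (assumes an eager Ops named tf and the usual Shape/TInt32 imports):

```java
Constant initializer = new Constant(tf, tf.constant(3f)); // any scalar operand
Operand<TFloat32> floats =
    initializer.call(tf.constant(Shape.of(2, 2)), TFloat32.class); // [[3.0, 3.0], [3.0, 3.0]]
Operand<TInt32> ints =
    initializer.call(tf.constant(Shape.of(2, 2)), TInt32.class); // same scalar, cast to int
// new Constant(tf, tf.constant(new float[] {1f, 2f})) throws IllegalArgumentException
```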
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Glorot.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Glorot.java
index 894bd073758..5a3c291785f 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Glorot.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Glorot.java
@@ -16,7 +16,6 @@
package org.tensorflow.framework.initializers;
import org.tensorflow.op.Ops;
-import org.tensorflow.types.family.TFloating;
/**
* The Glorot initializer, also called Xavier initializer.
@@ -58,16 +57,17 @@
*
*
* NOTE:
+ *
* <p>For a GlorotNormal equivalent initializer, use {@link
* VarianceScaling.Distribution#TRUNCATED_NORMAL} for the distribution parameter.
+ *
* <p>For a GlorotUniform equivalent initializer, use {@link VarianceScaling.Distribution#UNIFORM}
* for the distribution parameter.
*
- * @param <T> The TType for the call operation
* @see VarianceScaling.Distribution
* @see Glorot et al., 2010
*/
-public class Glorot<T extends TFloating> extends VarianceScaling<T> {
+public class Glorot extends VarianceScaling {
public static final double SCALE = 1.0;
@@ -77,7 +77,7 @@ public class Glorot extends VarianceScaling {
* @param tf the TensorFlow Ops
* @param distribution The distribution type for the Glorot initializer.
* @param seed the seed for random number generation. An initializer created with a given seed
- * will always produce the same random tensor for a given shape and dtype.
+ * will always produce the same random tensor for a given shape and data type.
* @see VarianceScaling.Distribution
*/
public Glorot(Ops tf, Distribution distribution, long seed) {
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/He.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/He.java
index 3a91b72b0d0..ac64e449265 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/He.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/He.java
@@ -15,7 +15,6 @@
package org.tensorflow.framework.initializers;
import org.tensorflow.op.Ops;
-import org.tensorflow.types.family.TFloating;
/**
* He initializer.
@@ -53,17 +52,18 @@
*
*
* NOTE:
+ *
* <p>For an HeNormal equivalent initializer, use {@link
* VarianceScaling.Distribution#TRUNCATED_NORMAL} for the distribution parameter.
- * <p>For an HeUniform equivalent initializer, use {@link VarianceScaling.Distribution#UNIFORM}
- * for the distribution parameter.
*
- * @param <T> The TType for the call operation
+ * <p>For an HeUniform equivalent initializer, use {@link VarianceScaling.Distribution#UNIFORM} for
+ * the distribution parameter.
+ *
* @see He
* et al., 2015
*/
-public class He<T extends TFloating> extends VarianceScaling<T> {
+public class He extends VarianceScaling {
public static final double SCALE = 2.0;
@@ -73,7 +73,7 @@ public class He extends VarianceScaling {
* @param tf the TensorFlow Ops
* @param distribution The distribution type for the He initializer.
* @param seed the seed for random number generation. An initializer created with a given seed
- * will always produce the same random tensor for a given shape and dtype.
+ * will always produce the same random tensor for a given shape and data type.
* @see VarianceScaling.Distribution
*/
public He(Ops tf, Distribution distribution, long seed) {
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Identity.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Identity.java
index f672c9f1e85..2b1ccf00bae 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Identity.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Identity.java
@@ -21,6 +21,8 @@
import org.tensorflow.types.TInt64;
import org.tensorflow.types.family.TFloating;
+import static org.tensorflow.framework.utils.CastHelper.cast;
+
/**
* Initializer that generates the identity matrix.
*
@@ -34,10 +36,8 @@
* Operand<TFloat32> values =
* initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class);
*
- *
- * @param <T> The TType for the call operation
*/
-public class Identity<T extends TFloating> extends BaseInitializer<T> {
+public class Identity extends BaseInitializer {
public static final double GAIN_DEFAULT = 1.0;
private final double gain;
@@ -65,7 +65,7 @@ public Identity(Ops tf, double gain) {
/** {@inheritDoc} */
@Override
- public Operand<T> call(Operand<TInt64> dims, Class<T> type) {
+ public <U extends TFloating> Operand<U> call(Operand<TInt64> dims, Class<U> type) {
Shape shape = ShapeUtils.toShape(tf.scope(), dims);
if (shape.numDimensions() != 2) {
throw new IllegalArgumentException("2D matrix required, got " + shape.numDimensions());
@@ -74,10 +74,10 @@ public Operand call(Operand dims, Class type) {
long diagSize = Math.min(shape.size(0), shape.size(1));
Shape diagShape = Shape.of(diagSize);
- Operand<T> op;
- Operand<T> zero = tf.dtypes.cast(tf.constant(0), type);
- Operand<T> diagOnes =
- tf.fill(tf.constant(diagShape.asArray()), tf.dtypes.cast(tf.constant(1.0), type));
+ Operand<U> op;
+ Operand<U> zero = cast(tf, tf.constant(0), type);
+ Operand<U> diagOnes =
+ tf.fill(tf.constant(diagShape.asArray()), cast(tf, tf.constant(1.0), type));
if (isSquare) {
op =
tf.linalg.matrixDiag(
@@ -87,10 +87,10 @@ public Operand call(Operand dims, Class type) {
tf.constant((int) shape.size(1)),
zero);
} else {
- Operand<T> zeroMatrix = tf.zeros(dims, type);
+ Operand<U> zeroMatrix = tf.zeros(dims, type);
op = tf.linalg.matrixSetDiag(zeroMatrix, diagOnes, tf.constant(0));
}
- return tf.math.mul(op, tf.dtypes.cast(tf.constant(gain), type));
+ return tf.math.mul(op, cast(tf, tf.constant(gain), type));
}
}
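A usage sketch for the rewritten call, showing the non-square branch (assumes an eager Ops named tf):

```java
Identity initializer = new Identity(tf, 2.0); // gain = 2
Operand<TFloat32> values = initializer.call(tf.constant(Shape.of(2, 3)), TFloat32.class);
// a min(2, 3) diagonal of ones, scaled by gain:
// [[2, 0, 0],
//  [0, 2, 0]]
// any shape whose rank is not 2 throws IllegalArgumentException
```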
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Initializer.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Initializer.java
index 4beb218783b..2f30bc5e99d 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Initializer.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Initializer.java
@@ -18,19 +18,16 @@
import org.tensorflow.types.TInt64;
import org.tensorflow.types.family.TType;
-/**
- * An interface for Initializers
- *
- * @param <T> The data Type for initializer operation
- */
+/** An interface for Initializers */
-public interface Initializer<T extends TType> {
+public interface Initializer {
/**
* Generates the operation used to perform the initialization.
*
* @param dims the shape dimensions
- * @param type the type of tensor
+ * @param type the data type of tensor
+ * @param <U> The data Type for initializer operation
* @return An operand for the initialization.
*/
- Operand<T> call(Operand<TInt64> dims, Class<T> type);
+ <U extends TType> Operand<U> call(Operand<TInt64> dims, Class<U> type);
}
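One side effect of moving the generic onto the method is that Initializer can no longer be implemented with a lambda, because lambdas cannot implement generic methods; ad-hoc initializers need an anonymous class. A sketch with a hypothetical constant-fill initializer, assuming the signature shown above:

```java
Initializer fill42 =
    new Initializer() {
      @Override
      public <U extends TType> Operand<U> call(Operand<TInt64> dims, Class<U> type) {
        // fill the requested shape with 42, cast to the requested type
        return tf.fill(dims, tf.dtypes.cast(tf.constant(42), type));
      }
    };
```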
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/LeCun.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/LeCun.java
index 38e68ef688b..b82f40918c0 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/LeCun.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/LeCun.java
@@ -15,7 +15,6 @@
package org.tensorflow.framework.initializers;
import org.tensorflow.op.Ops;
-import org.tensorflow.types.family.TFloating;
/**
* LeCun normal initializer.
@@ -27,7 +26,7 @@
* stddev = sqrt(1 / fanIn) where fanIn is the number of input units in the
* weight tensor.
*
- * If the distribution is UNIFORM, itraws samples from a uniform distribution within
+ * If the distribution is UNIFORM, it draws samples from a uniform distribution within
* [-limit, limit], where limit = Math.sqrt(3 / fanIn) (fanIn is
* the number of input units in the weight tensor)
*
@@ -59,14 +58,14 @@
*
* NOTE: *
*
- * <p>For a LeCunNormal equivalent initializer, use {@link VarianceScaling.Distribution#TRUNCATED_NORMAL} for the distribution parameter. *
+ * <p>For a LeCunNormal equivalent initializer, use {@link
+ * VarianceScaling.Distribution#TRUNCATED_NORMAL} for the distribution parameter. *
*
* <p>For a LeCunUniform equivalent initializer, use {@link VarianceScaling.Distribution#UNIFORM} *
* for the distribution parameter. *
*
*
*
- * @param <T> The TType for the call operation
* @see Self-Normalizing
* Neural Networks, Klambauer et al., 2017
@@ -74,7 +73,7 @@
* al., 1998
* @see VarianceScaling.Distribution
*/
-public class LeCun<T extends TFloating> extends VarianceScaling<T> {
+public class LeCun extends VarianceScaling {
/**
* Creates a LeCunNormal Initializer
@@ -82,7 +81,7 @@ public class LeCun extends VarianceScaling {
* @param tf the TensorFlow Ops
* @param distribution The distribution type for the Glorot initializer.
* @param seed the seed for random number generation. An initializer created with a given seed
- * will always produce the same random tensor for a given shape and dtype.
+ * will always produce the same random tensor for a given shape and data type.
*/
public LeCun(Ops tf, Distribution distribution, long seed) {
super(tf, 1.0, Mode.FAN_IN, distribution, seed);
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Ones.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Ones.java
index b8eb0c418e9..5cf8c7a8033 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Ones.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Ones.java
@@ -21,6 +21,8 @@
import org.tensorflow.types.family.TNumber;
import org.tensorflow.types.family.TType;
+import static org.tensorflow.framework.utils.CastHelper.cast;
+
/**
* Initializer that generates tensors initialized to 1.
*
@@ -32,10 +34,8 @@
* Operand<TFloat32> values =
* initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class);
*
- *
- * @param <T> The TType for the call operation
*/
-public class Ones<T extends TType> extends BaseInitializer<T> {
+public class Ones extends BaseInitializer {
/**
* Creates an Initializer that sets all values to one.
@@ -55,12 +55,22 @@ public Ones(Ops tf) {
super(tf);
}
- /** {@inheritDoc} */
+ /**
+ * Generates the operation used to perform the initialization.
+ *
+ * @param dims the shape dimensions
+ * @param type the data type of tensor
+ * @param <U> The data Type for initializer operation
+ * @return An operand for the initialization.
+ * @throws IllegalArgumentException if the data type is not a TNumber or TBool
+ */
@Override
- public Operand<T> call(Operand<TInt64> dims, Class<T> type) {
+ public <U extends TType> Operand<U> call(Operand<TInt64> dims, Class<U> type) {
if (!TNumber.class.isAssignableFrom(type) && type != TBool.class) {
- throw new IllegalArgumentException("Tensor type must be numeric or boolean: " + type.getSimpleName());
+ throw new IllegalArgumentException(
+ "Tensor type must be numeric or boolean: " + type.getSimpleName());
}
- return tf.fill(dims, tf.dtypes.cast(tf.constant(1.0), type));
+
+ return tf.fill(dims, cast(tf, tf.constant(1), type));
}
}
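The widened TType bound plus the runtime check keeps boolean tensors working while rejecting non-numeric types. A sketch (assumes an eager Ops named tf):

```java
Ones ones = new Ones(tf);
Operand<TFloat32> f = ones.call(tf.constant(Shape.of(2, 2)), TFloat32.class); // all 1.0f
Operand<TBool> b = ones.call(tf.constant(Shape.of(2)), TBool.class); // all true
// ones.call(tf.constant(Shape.of(2)), TString.class) throws IllegalArgumentException
```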
diff --git a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Orthogonal.java b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Orthogonal.java
index a5b466e118e..14f1049d038 100644
--- a/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Orthogonal.java
+++ b/tensorflow-framework/src/main/java/org/tensorflow/framework/initializers/Orthogonal.java
@@ -23,6 +23,8 @@
import org.tensorflow.types.TInt64;
import org.tensorflow.types.family.TFloating;
+import static org.tensorflow.framework.utils.CastHelper.cast;
+
/**
* Initializer that generates an orthogonal matrix.
*
@@ -44,10 +46,8 @@
* Operand<TFloat32> values =
* initializer.call(tf.constant(Shape.of(2,2)), TFloat32.class);
*
- *
- * @param <T> The TType for the call operation
*/
-public class Orthogonal<T extends TFloating> extends BaseInitializer<T> {
+public class Orthogonal extends BaseInitializer {
public static final double GAIN_DEFAULT = 1.0;
@@ -59,7 +59,7 @@ public class Orthogonal extends BaseInitializer {
*
* @param tf the TensorFlow Ops
* @param seed the seed for random number generation. An initializer created with a given seed
- * will always produce the same random tensor for a given shape and dtype.
+ * will always produce the same random tensor for a given shape and data type.
*/
public Orthogonal(Ops tf, long seed) {
this(tf, GAIN_DEFAULT, seed);
@@ -71,7 +71,7 @@ public Orthogonal(Ops tf, long seed) {
* @param tf the TensorFlow Ops
* @param gain the gain to be applied to the Matrix.
* @param seed the seed for random number generation. An initializer created with a given seed
- * will always produce the same random tensor for a given shape and dtype.
+ * will always produce the same random tensor for a given shape and data type.
*/
public Orthogonal(Ops tf, double gain, long seed) {
super(tf);
@@ -81,7 +81,8 @@ public Orthogonal(Ops tf, double gain, long seed) {
/** {@inheritDoc} */
@Override
- public Operand<T> call(Operand<TInt64> dims, Class<T> type) {
+ public <U extends TFloating> Operand<U> call(Operand<TInt64> dims, Class<U> type) {
+
Shape dimsShape = ShapeUtils.toShape(tf.scope(), dims);
if (dimsShape.numDimensions() < 2) {
throw new IllegalArgumentException(
@@ -94,17 +95,20 @@ public Operand call(Operand dims, Class type) {
long numCols = dimsShape.size(i);
Shape flatShape = Shape.of(Math.max(numRows, numCols), Math.min(numRows, numCols));
long[] seeds = {seed, 0};
- Operand<T> op =
+ Operand