From eaa4052e6fbeb2c0bebb6ceb39f8f00ed0cf4997 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Mon, 3 Feb 2025 12:31:44 -0800 Subject: [PATCH] Optimize `spectral_normalization` implementation. PiperOrigin-RevId: 722757020 --- .../normalization/spectral_normalization.py | 51 +++++++++++-------- 1 file changed, 29 insertions(+), 22 deletions(-) diff --git a/tf_keras/layers/normalization/spectral_normalization.py b/tf_keras/layers/normalization/spectral_normalization.py index 79a9df05a..17d62b1da 100644 --- a/tf_keras/layers/normalization/spectral_normalization.py +++ b/tf_keras/layers/normalization/spectral_normalization.py @@ -95,7 +95,7 @@ def build(self, input_shape): def call(self, inputs, training=False): if training: - self.normalize_weights() + self._update_weights() output = self.layer(inputs) return output @@ -105,35 +105,42 @@ def compute_output_shape(self, input_shape): self.layer.compute_output_shape(input_shape).as_list() ) + def _update_weights(self): + weights = self.kernel + vector_u = self.vector_u + + kernel_weights, vector_u = tf.cond( + tf.reduce_all(tf.equal(weights, 0)), + lambda: (weights, vector_u), + lambda: self.normalize_weights(), + ) + self.kernel.assign(kernel_weights) + self.vector_u.assign(vector_u) + def normalize_weights(self): """Generate spectral normalized weights. This method will update the value of `self.kernel` with the spectral normalized value, so that the layer is ready for `call()`. """ - - weights = tf.reshape(self.kernel, [-1, self.kernel_shape[-1]]) + # Initialize vector_v to hint the compiler it always exist. vector_u = self.vector_u - - # check for zeroes weights - if not tf.reduce_all(tf.equal(weights, 0.0)): - for _ in range(self.power_iterations): - vector_v = tf.math.l2_normalize( - tf.matmul(vector_u, weights, transpose_b=True) - ) - vector_u = tf.math.l2_normalize(tf.matmul(vector_v, weights)) - vector_u = tf.stop_gradient(vector_u) - vector_v = tf.stop_gradient(vector_v) - sigma = tf.matmul( - tf.matmul(vector_v, weights), vector_u, transpose_b=True - ) - self.vector_u.assign(tf.cast(vector_u, self.vector_u.dtype)) - self.kernel.assign( - tf.cast( - tf.reshape(self.kernel / sigma, self.kernel_shape), - self.kernel.dtype, - ) + vector_v = self.vector_u + weights = tf.reshape(self.kernel, [-1, self.kernel_shape[-1]]) + for _ in range(self.power_iterations): + vector_v = tf.math.l2_normalize( + tf.matmul(vector_u, weights, transpose_b=True) ) + vector_u = tf.math.l2_normalize(tf.matmul(vector_v, weights)) + vector_u = tf.stop_gradient(vector_u) + vector_v = tf.stop_gradient(vector_v) + sigma = tf.matmul( + tf.matmul(vector_v, weights), + vector_u, + transpose_b=True, + ) + weights_normalized = tf.reshape(weights / sigma, self.kernel_shape) + return weights_normalized, vector_u def get_config(self): config = {"power_iterations": self.power_iterations}