Optimize spectral_normalization implementation.
PiperOrigin-RevId: 722757020
tensorflower-gardener committed Feb 3, 2025
Parent: 3f00c13 · Commit: eaa4052
Showing 1 changed file with 29 additions and 22 deletions.
tf_keras/layers/normalization/spectral_normalization.py
@@ -95,7 +95,7 @@ def build(self, input_shape):

     def call(self, inputs, training=False):
         if training:
-            self.normalize_weights()
+            self._update_weights()

         output = self.layer(inputs)
         return output
@@ -105,35 +105,42 @@ def compute_output_shape(self, input_shape):
             self.layer.compute_output_shape(input_shape).as_list()
         )

+    def _update_weights(self):
+        weights = self.kernel
+        vector_u = self.vector_u
+
+        kernel_weights, vector_u = tf.cond(
+            tf.reduce_all(tf.equal(weights, 0)),
+            lambda: (weights, vector_u),
+            lambda: self.normalize_weights(),
+        )
+        self.kernel.assign(kernel_weights)
+        self.vector_u.assign(vector_u)
+
     def normalize_weights(self):
         """Generate spectral normalized weights.

         This method will update the value of `self.kernel` with the
         spectral normalized value, so that the layer is ready for `call()`.
         """

-        weights = tf.reshape(self.kernel, [-1, self.kernel_shape[-1]])
-        vector_u = self.vector_u
-
-        # check for zeroes weights
-        if not tf.reduce_all(tf.equal(weights, 0.0)):
-            for _ in range(self.power_iterations):
-                vector_v = tf.math.l2_normalize(
-                    tf.matmul(vector_u, weights, transpose_b=True)
-                )
-                vector_u = tf.math.l2_normalize(tf.matmul(vector_v, weights))
-            vector_u = tf.stop_gradient(vector_u)
-            vector_v = tf.stop_gradient(vector_v)
-            sigma = tf.matmul(
-                tf.matmul(vector_v, weights), vector_u, transpose_b=True
-            )
-            self.vector_u.assign(tf.cast(vector_u, self.vector_u.dtype))
-            self.kernel.assign(
-                tf.cast(
-                    tf.reshape(self.kernel / sigma, self.kernel_shape),
-                    self.kernel.dtype,
-                )
-            )
+        # Initialize vector_v to hint the compiler it always exist.
+        vector_u = self.vector_u
+        vector_v = self.vector_u
+        weights = tf.reshape(self.kernel, [-1, self.kernel_shape[-1]])
+        for _ in range(self.power_iterations):
+            vector_v = tf.math.l2_normalize(
+                tf.matmul(vector_u, weights, transpose_b=True)
+            )
+            vector_u = tf.math.l2_normalize(tf.matmul(vector_v, weights))
+        vector_u = tf.stop_gradient(vector_u)
+        vector_v = tf.stop_gradient(vector_v)
+        sigma = tf.matmul(
+            tf.matmul(vector_v, weights),
+            vector_u,
+            transpose_b=True,
+        )
+        weights_normalized = tf.reshape(weights / sigma, self.kernel_shape)
+        return weights_normalized, vector_u

     def get_config(self):
         config = {"power_iterations": self.power_iterations}
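
For readers who want to experiment with the new control flow outside the layer, here is a minimal standalone sketch (not the library code): a tf.cond that leaves an all-zero kernel untouched, plus the power iteration that estimates the kernel's largest singular value (sigma) before dividing it out. The helper name spectral_normalize and the example tensors are illustrative only.

import tensorflow as tf

def spectral_normalize(kernel, vector_u, power_iterations=1):
    """Return (kernel / sigma, updated u); leave an all-zero kernel untouched."""

    def _normalize():
        weights = tf.reshape(kernel, [-1, kernel.shape[-1]])
        u = vector_u
        v = vector_u  # placeholder so v is defined on every trace
        for _ in range(power_iterations):
            v = tf.math.l2_normalize(tf.matmul(u, weights, transpose_b=True))
            u = tf.math.l2_normalize(tf.matmul(v, weights))
        u = tf.stop_gradient(u)
        v = tf.stop_gradient(v)
        # sigma approximates the largest singular value of `weights`.
        sigma = tf.matmul(tf.matmul(v, weights), u, transpose_b=True)
        return tf.reshape(weights / sigma, kernel.shape), u

    # Both branches return the same structure; the zero-kernel branch is a no-op.
    return tf.cond(
        tf.reduce_all(tf.equal(kernel, 0)),
        lambda: (kernel, vector_u),
        _normalize,
    )

kernel = tf.random.normal([3, 3, 8, 16])              # e.g. a Conv2D kernel
u0 = tf.math.l2_normalize(tf.random.normal([1, 16]))  # same shape the layer uses for vector_u
normalized, u1 = spectral_normalize(kernel, u0, power_iterations=10)

# The leading singular value of the flattened, normalized kernel should be close to 1.
print(tf.linalg.svd(tf.reshape(normalized, [-1, 16]), compute_uv=False)[0])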

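On the caller side nothing changes. Assuming the wrapper is exported as tf_keras.layers.SpectralNormalization and takes a power_iterations argument (as get_config in the diff suggests), typical usage would look like:

import tensorflow as tf
import tf_keras

# Wrap a layer whose kernel should be spectrally normalized.
layer = tf_keras.layers.SpectralNormalization(
    tf_keras.layers.Dense(10), power_iterations=1
)

x = tf.random.normal([4, 32])
y = layer(x, training=True)  # training=True triggers _update_weights()
print(y.shape)               # (4, 10)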