From d6dd2e8e16145e73f69664bc81690ac06857319b Mon Sep 17 00:00:00 2001 From: init-22 Date: Tue, 17 Dec 2024 21:50:11 +0530 Subject: [PATCH] fix: resolving pylint issues in custom_tf_addons --- .../imagenet_jax/custom_tf_addons.py | 27 +++++++++---------- .../imagenet_jax/randaugment.py | 4 +-- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py index eda67d226..79aef6791 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/custom_tf_addons.py @@ -6,8 +6,7 @@ """ -import math -from typing import Callable, List, Optional, Union +from typing import List, Optional, Union import numpy as np import tensorflow as tf @@ -48,7 +47,7 @@ def get_ndims(image): return image.get_shape().ndims or tf.rank(image) -def to_4D_image(image): +def to_4d_image(image): """Convert 2/3/4D image to 4D image. Args: @@ -63,7 +62,7 @@ def to_4D_image(image): ]): ndims = image.get_shape().ndims if ndims is None: - return _dynamic_to_4D_image(image) + return _dynamic_to_4d_image(image) elif ndims == 2: return image[None, :, :, None] elif ndims == 3: @@ -72,7 +71,7 @@ def to_4D_image(image): return image -def _dynamic_to_4D_image(image): +def _dynamic_to_4d_image(image): shape = tf.shape(image) original_rank = tf.rank(image) # 4D image => [N, H, W, C] or [N, C, H, W] @@ -91,7 +90,7 @@ def _dynamic_to_4D_image(image): return tf.reshape(image, new_shape) -def from_4D_image(image, ndims): +def from_4d_image(image, ndims): """Convert back to an image with `ndims` rank. Args: @@ -105,7 +104,7 @@ def from_4D_image(image, ndims): [tf.debugging.assert_rank(image, 4, message="`image` must be 4D tensor")]): if isinstance(ndims, tf.Tensor): - return _dynamic_from_4D_image(image, ndims) + return _dynamic_from_4d_image(image, ndims) elif ndims == 2: return tf.squeeze(image, [0, 3]) elif ndims == 3: @@ -114,7 +113,7 @@ def from_4D_image(image, ndims): return image -def _dynamic_from_4D_image(image, original_rank): +def _dynamic_from_4d_image(image, original_rank): shape = tf.shape(image) # 4D image <= [N, H, W, C] or [N, C, H, W] # 3D image <= [1, H, W, C] or [1, C, H, W] @@ -183,7 +182,7 @@ def transform( transforms, name="transforms", dtype=tf.dtypes.float32) if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES: raise TypeError("Invalid dtype %s." % image_or_images.dtype) - images = to_4D_image(image_or_images) + images = to_4d_image(image_or_images) original_ndims = get_ndims(image_or_images) if output_shape is None: @@ -217,7 +216,7 @@ def transform( fill_mode=fill_mode.upper(), fill_value=fill_value, ) - return from_4D_image(output, original_ndims) + return from_4d_image(output, original_ndims) def angles_to_projective_transforms( @@ -271,7 +270,7 @@ def angles_to_projective_transforms( ) -def rotate( +def rotate_img( images: TensorLike, angles: TensorLike, interpolation: str = "nearest", @@ -286,7 +285,7 @@ def rotate( `(num_images, num_rows, num_columns, num_channels)` (NHWC), `(num_rows, num_columns, num_channels)` (HWC), or `(num_rows, num_columns)` (HW). - angles: A scalar angle to rotate all images by, or (if `images` has rank 4) + angles: A scalar angle to rotate all images by (if `images` has rank 4) a vector of length num_images, with an angle for each image in the batch. interpolation: Interpolation mode. Supported values: "nearest", @@ -317,7 +316,7 @@ def rotate( image_or_images = tf.convert_to_tensor(images) if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES: raise TypeError("Invalid dtype %s." % image_or_images.dtype) - images = to_4D_image(image_or_images) + images = to_4d_image(image_or_images) original_ndims = get_ndims(image_or_images) image_height = tf.cast(tf.shape(images)[1], tf.dtypes.float32)[None] @@ -329,7 +328,7 @@ def rotate( fill_mode=fill_mode, fill_value=fill_value, ) - return from_4D_image(output, original_ndims) + return from_4d_image(output, original_ndims) def translations_to_projective_transforms(translations: TensorLike, diff --git a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py index f3a946245..dd00146cd 100644 --- a/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py +++ b/algorithmic_efficiency/workloads/imagenet_resnet/imagenet_jax/randaugment.py @@ -9,7 +9,7 @@ import tensorflow as tf -from .custom_tf_addons import rotate +from .custom_tf_addons import rotate_img from .custom_tf_addons import transform from .custom_tf_addons import translate @@ -179,7 +179,7 @@ def rotate(image, degrees, replace): # In practice, we should randomize the rotation degrees by flipping # it negatively half the time, but that's done on 'degrees' outside # of the function. - image = rotate(wrap(image), radians) + image = rotate_img(wrap(image), radians) return unwrap(image, replace)