diff --git a/.pylintrc b/.pylintrc index 467c1f9..f9257ae 100644 --- a/.pylintrc +++ b/.pylintrc @@ -273,7 +273,7 @@ generated-members= # Maximum number of characters on a single line. max-line-length=80 -# TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt +# TODO: Direct pylint to exempt # lines made too long by directives to pytype. # Regexp for a line that is allowed to be longer than the limit. diff --git a/swirl_dynamics/lib/diffusion/vivit.py b/swirl_dynamics/lib/diffusion/vivit.py index e9837f9..9903d76 100644 --- a/swirl_dynamics/lib/diffusion/vivit.py +++ b/swirl_dynamics/lib/diffusion/vivit.py @@ -373,7 +373,7 @@ def __call__(self, inputs: Array, *, train: bool) -> Array: name='conv_transpose_temporal_decoder', )(x) - # TODO(lzepedanunez): Use unets.depth_to_space here instead. + # TODO: Use unets.depth_to_space here instead. x = jnp.reshape( x, (batch_size, *self.encoded_shapes, t, h, w, self.features_out) ) @@ -665,7 +665,7 @@ class TransformerBlock(nn.Module): mlp_dim: int num_layers: int num_heads: int - # TODO(lzepedanunez): encapsulate the configurations in its own container. + # TODO: encapsulate the configurations in its own container. attention_config: ml_collections.ConfigDict | None = None dropout_rate: float = 0.1 attention_dropout_rate: float = 0.1 @@ -682,10 +682,10 @@ def __call__(self, inputs: Array, *, train: bool) -> Array: dtype = jax.dtypes.canonicalize_dtype(self.dtype) # Computing positional embeddings. - # TODO(lzepedanunez): Introduce more types of positional encoding. + # TODO: Introduce more types of positional encoding. if self.positional_embedding == 'sinusoidal_3d': batch, num_tokens, hidden_dim = inputs.shape - # TODO(lzepedanunez): change this one to handle non-square domains. + # TODO: change this one to handle non-square domains. height = width = int(np.sqrt(num_tokens // self.temporal_dims)) if height * width * self.temporal_dims != num_tokens: raise ValueError('Input is assumed to be square for sinusoidal init.') diff --git a/swirl_dynamics/lib/diffusion/vivit_diffusion.py b/swirl_dynamics/lib/diffusion/vivit_diffusion.py index 90f7278..28198c6 100644 --- a/swirl_dynamics/lib/diffusion/vivit_diffusion.py +++ b/swirl_dynamics/lib/diffusion/vivit_diffusion.py @@ -291,7 +291,7 @@ def __call__(self, inputs: Array, emb: Array, *, train: bool) -> Array: dtype = jax.dtypes.canonicalize_dtype(self.dtype) # Choosing the type of embedding. - # TODO(lzepedanunez): add more embeddings in here. + # TODO: add more embeddings in here. if self.positional_embedding == 'sinusoidal_3d': batch, num_tokens, hidden_dim = inputs.shape height = width = int(np.sqrt(num_tokens // self.temporal_dims)) @@ -324,7 +324,7 @@ def __call__(self, inputs: Array, emb: Array, *, train: bool) -> Array: self.attention_config.get('attention_kernel_init_method', 'xavier')], # pytype: disable=attribute-error temporal_dims=self.temporal_dims) - # TODO(lzepedanunez): implement factorized_dot_product_attention. + # TODO: implement factorized_dot_product_attention. else: raise ValueError(f'Unknown attention type {self.attention_config.type}') # pytype: disable=attribute-error diff --git a/swirl_dynamics/lib/networks/cycle_gan.py b/swirl_dynamics/lib/networks/cycle_gan.py index 147c796..2f4dcd4 100644 --- a/swirl_dynamics/lib/networks/cycle_gan.py +++ b/swirl_dynamics/lib/networks/cycle_gan.py @@ -199,7 +199,7 @@ class Generator(nn.Module): use_skips: bool = True use_global_skip: bool = True dtype: jnp.dtype = jnp.float32 - padding: str = "CIRCULAR" # TODO(lzepedanunez): Add one adapted for ERA5. + padding: str = "CIRCULAR" # TODO: Add one adapted for ERA5. padding_transpose: str = "CIRCULAR" use_weight_global_skip: bool = False weight_skip: bool = False @@ -288,7 +288,7 @@ def __call__(self, x: Array, is_training: bool) -> Array: )(x) # Use a transformer core. - # TODO(lzepedanunez) add a conformer model. + # TODO add a conformer model. if self.use_attention: b, *hw, c = x.shape # Adding positional encoding. @@ -339,7 +339,7 @@ def __call__(self, x: Array, is_training: bool) -> Array: )(x) elif self.upsample_mode == "deconv": - # TODO(lzepedanunez): use channel unrolling for the upsampling. + # TODO: use channel unrolling for the upsampling. x = nn.ConvTranspose( features=(self.ngf * mult) // 2, kernel_size=self.kernel_size_upsampling, diff --git a/swirl_dynamics/lib/networks/nonlinear_fourier.py b/swirl_dynamics/lib/networks/nonlinear_fourier.py index 900a529..6c8e190 100644 --- a/swirl_dynamics/lib/networks/nonlinear_fourier.py +++ b/swirl_dynamics/lib/networks/nonlinear_fourier.py @@ -245,7 +245,7 @@ def __call__(self, inputs: Array) -> Array: # shape : (2, num_freqs, 2) for sin-cos, \omega, and x-y. y = omega * (x_i.reshape((1, 1, 2)) + a) - # TODO(lzepedanunez): create a funcion that creates the periodic features. + # TODO: create a funcion that creates the periodic features. # Applying the trigonometric functions, which can be written as: # [[1, 1], # [sin(ω₁ x), sin(ω₁ y)], diff --git a/swirl_dynamics/lib/networks/rational_networks.py b/swirl_dynamics/lib/networks/rational_networks.py index 02dfcfc..6b2d617 100644 --- a/swirl_dynamics/lib/networks/rational_networks.py +++ b/swirl_dynamics/lib/networks/rational_networks.py @@ -162,7 +162,7 @@ class RationalMLP(nn.Module): dtype: Any = jnp.float32 multi_rational: bool = False use_bias: bool = True - # TODO(lzepedanunez): add precision flag to have more granular control + # TODO: add precision flag to have more granular control @nn.compact def __call__(self, inputs: Array) -> Array: diff --git a/swirl_dynamics/projects/debiasing/cycle_gan/models.py b/swirl_dynamics/projects/debiasing/cycle_gan/models.py index ccff3d5..45d851b 100644 --- a/swirl_dynamics/projects/debiasing/cycle_gan/models.py +++ b/swirl_dynamics/projects/debiasing/cycle_gan/models.py @@ -175,7 +175,7 @@ def run_generator_forward( A tuple containing the generated samples. """ - # TODO(lzepedanunez): perhaps use dictionaries instead of positional tuples. + # TODO: perhaps use dictionaries instead of positional tuples. params_gen_a2b = params_gen[0] params_gen_b2a = params_gen[1] @@ -480,7 +480,7 @@ def loss_fn( to be real data. """ # Split the States. - # TODO(lzepedanunez): specify how to split the parameters. + # TODO: specify how to split the parameters. params_gen_a2b = params[0] params_gen_b2a = params[1] diff --git a/swirl_dynamics/projects/debiasing/rectified_flow/data_utils.py b/swirl_dynamics/projects/debiasing/rectified_flow/data_utils.py index 07cb1e8..34fcba4 100644 --- a/swirl_dynamics/projects/debiasing/rectified_flow/data_utils.py +++ b/swirl_dynamics/projects/debiasing/rectified_flow/data_utils.py @@ -125,7 +125,7 @@ def create_loader_from_hdf5( mean and std stats (if normalize=True, else dict contains NoneType values). """ - # TODO(lzepedanunez): create the data arrays following a similar convention. + # TODO: create the data arrays following a similar convention. snapshots = hdf5_utils.read_single_array( dataset_path, f"{split}/u", diff --git a/swirl_dynamics/projects/debiasing/rectified_flow/main.py b/swirl_dynamics/projects/debiasing/rectified_flow/main.py index fad9be9..56db5bf 100644 --- a/swirl_dynamics/projects/debiasing/rectified_flow/main.py +++ b/swirl_dynamics/projects/debiasing/rectified_flow/main.py @@ -119,7 +119,7 @@ def main(argv): ) model = models.ReFlowModel( - # TODO(lzepedanunez): clean this part. + # TODO: clean this part. input_shape=( config.input_shapes[0][1] // config.spatial_downsample_factor[0], config.input_shapes[0][2] // config.spatial_downsample_factor[0], @@ -157,7 +157,7 @@ def main(argv): base_dir=workdir, options=ckpt_options, ), - # TODO(lzepedanunez) add a plot callback. + # TODO add a plot callback. ), ) diff --git a/swirl_dynamics/projects/debiasing/rectified_flow/models.py b/swirl_dynamics/projects/debiasing/rectified_flow/models.py index 9a6a2a6..207d997 100644 --- a/swirl_dynamics/projects/debiasing/rectified_flow/models.py +++ b/swirl_dynamics/projects/debiasing/rectified_flow/models.py @@ -99,7 +99,7 @@ class ReFlowModel(models.BaseModel): num_eval_time_levels: ClassVar[int] = 10 def initialize(self, rng: Array): - # TODO(lzepedanunez): Add a dtype object to ensure consistency of types. + # TODO: Add a dtype object to ensure consistency of types. x = jnp.ones((1,) + self.input_shape) return self.flow_model.init( diff --git a/swirl_dynamics/projects/debiasing/rectified_flow/trainers.py b/swirl_dynamics/projects/debiasing/rectified_flow/trainers.py index 748101a..380655d 100644 --- a/swirl_dynamics/projects/debiasing/rectified_flow/trainers.py +++ b/swirl_dynamics/projects/debiasing/rectified_flow/trainers.py @@ -98,7 +98,7 @@ class DistributedReFlowTrainer( ): """Multi-device trainer for rectified flow models.""" - # TODO(lzepedanunez): Write a test for this trainer. + # TODO: Write a test for this trainer. # MRO: ReFlowTrainer > BasicDistributedTrainer > BasicTrainer ... diff --git a/swirl_dynamics/projects/ergodic/choices.py b/swirl_dynamics/projects/ergodic/choices.py index 0b8e13c..f92c4e8 100644 --- a/swirl_dynamics/projects/ergodic/choices.py +++ b/swirl_dynamics/projects/ergodic/choices.py @@ -67,7 +67,7 @@ def dispatch( Returns: ScanOdeSolver | MultiStepScanOdeSolver """ - # TODO(yairschiff): Profile if the moveaxis call required here introduces a + # TODO: Profile if the moveaxis call required here introduces a # bottleneck return { "ExplicitEuler": ode.ExplicitEuler(time_axis_pos=1), diff --git a/swirl_dynamics/projects/ergodic/configs/ks_1d.py b/swirl_dynamics/projects/ergodic/configs/ks_1d.py index 75dfda2..048fa75 100644 --- a/swirl_dynamics/projects/ergodic/configs/ks_1d.py +++ b/swirl_dynamics/projects/ergodic/configs/ks_1d.py @@ -53,7 +53,7 @@ def get_config(): # Model params config.model = 'PeriodicConvNetModel' # 'Fno' - # TODO(yairschiff): Split CNN and FNO into separate configs + # TODO: Split CNN and FNO into separate configs ########### PeriodicConvNetModel ################ config.latent_dim = 48 config.num_levels = 4 @@ -120,7 +120,7 @@ def skip( return False -# TODO(yairschiff): Refactor sweeps and experiment definition to use gin. +# TODO: Refactor sweeps and experiment definition to use gin. # use option --sweep=False in the command line to avoid sweeping def sweep(add): """Define param sweep.""" @@ -165,7 +165,7 @@ def sweep(add): ) -# TODO(yairschiff): Ablation! +# TODO: Ablation! # def sweep(add): # """Define param sweep.""" # # pylint: disable=line-too-long diff --git a/swirl_dynamics/projects/ergodic/configs/ks_1d_dist.py b/swirl_dynamics/projects/ergodic/configs/ks_1d_dist.py index 3c9106c..325e306 100644 --- a/swirl_dynamics/projects/ergodic/configs/ks_1d_dist.py +++ b/swirl_dynamics/projects/ergodic/configs/ks_1d_dist.py @@ -113,7 +113,7 @@ def skip( return False -# TODO(yairschiff): Refactor sweeps and experiment definition to use gin. +# TODO: Refactor sweeps and experiment definition to use gin. # use option --sweep=False in the command line to avoid sweeping def sweep(add): """Define param sweep.""" diff --git a/swirl_dynamics/projects/ergodic/configs/lorenz63.py b/swirl_dynamics/projects/ergodic/configs/lorenz63.py index a73012d..92f4916 100644 --- a/swirl_dynamics/projects/ergodic/configs/lorenz63.py +++ b/swirl_dynamics/projects/ergodic/configs/lorenz63.py @@ -96,7 +96,7 @@ def skip( return False -# TODO(yairschiff): Refactor sweeps and experiment definition to use gin. +# TODO: Refactor sweeps and experiment definition to use gin. def sweep(add): """Define param sweep.""" # pylint: disable=line-too-long diff --git a/swirl_dynamics/projects/ergodic/configs/ns_2d.py b/swirl_dynamics/projects/ergodic/configs/ns_2d.py index 8352681..5c99f8e 100644 --- a/swirl_dynamics/projects/ergodic/configs/ns_2d.py +++ b/swirl_dynamics/projects/ergodic/configs/ns_2d.py @@ -54,7 +54,7 @@ def get_config(): config.noise_level = 0.0 # Model params - # TODO(yairschiff): Split CNN and FNO into separate configs + # TODO: Split CNN and FNO into separate configs config.model = 'PeriodicConvNetModel' # 'Fno' 'Fno2d' ########### PeriodicConvNetModel ################ config.latent_dim = 16 @@ -131,7 +131,7 @@ def skip( # pylint: disable=line-too-long -# TODO(yairschiff): Refactor sweeps and experiment definition to use gin. +# TODO: Refactor sweeps and experiment definition to use gin. # use option --sweep=False in the command line to avoid sweeping def sweep(add): """Define param sweep.""" diff --git a/swirl_dynamics/projects/ergodic/configs/ns_2d_dist.py b/swirl_dynamics/projects/ergodic/configs/ns_2d_dist.py index 1795e0a..da19d83 100644 --- a/swirl_dynamics/projects/ergodic/configs/ns_2d_dist.py +++ b/swirl_dynamics/projects/ergodic/configs/ns_2d_dist.py @@ -89,7 +89,7 @@ def get_config(): return config -# TODO(yairschiff): Refactor sweeps and experiment definition to use gin. +# TODO: Refactor sweeps and experiment definition to use gin. def sweep(add): """Define param sweep.""" for seed in [42]: diff --git a/swirl_dynamics/projects/ergodic/main.py b/swirl_dynamics/projects/ergodic/main.py index 508d5e7..a57ed5e 100644 --- a/swirl_dynamics/projects/ergodic/main.py +++ b/swirl_dynamics/projects/ergodic/main.py @@ -13,7 +13,7 @@ # limitations under the License. r"""The main entry point for running training loops.""" -# TODO(yairschiff): Consider enabling float64 for Lorenz63 experiment +# TODO: Consider enabling float64 for Lorenz63 experiment import json from os import path as osp @@ -82,7 +82,7 @@ def main(argv): elif experiment == choices.Experiment.NS_2D: fig_callback_cls = ns_2d.NS2dPlotFigures - # TODO(yairschiff): This state dim is temporary for FNO data, should be 256 + # TODO: This state dim is temporary for FNO data, should be 256 state_dims = ( 64 // config.spatial_downsample_factor, 64 // config.spatial_downsample_factor, diff --git a/swirl_dynamics/projects/ergodic/measure_distances.py b/swirl_dynamics/projects/ergodic/measure_distances.py index b36947a..31970a0 100644 --- a/swirl_dynamics/projects/ergodic/measure_distances.py +++ b/swirl_dynamics/projects/ergodic/measure_distances.py @@ -70,7 +70,7 @@ def mmd(x: Array, y: Array) -> Array: xx, yy, xy = (jnp.zeros_like(xx), jnp.zeros_like(xx), jnp.zeros_like(xx)) # Multiscale - # TODO(yairschiff): We may need to experiment with these bandwidths to have + # TODO: We may need to experiment with these bandwidths to have # MMD loss better distinguish distributions, especially for high dim data bandwidth_range = [0.2, 0.5, 0.9, 1.3] for a in bandwidth_range: @@ -78,7 +78,7 @@ def mmd(x: Array, y: Array) -> Array: yy += a**2 * (a**2 + dyy) ** -1 xy += a**2 * (a**2 + dxy) ** -1 - # TODO(yairschiff): We may want to use jnp.sqrt(...) here; see: + # TODO: We may want to use jnp.sqrt(...) here; see: # https://arxiv.org/abs/1502.02761 return jnp.mean(xx + yy - 2.0 * xy) diff --git a/swirl_dynamics/projects/ergodic/stable_ar.py b/swirl_dynamics/projects/ergodic/stable_ar.py index 9d8cf40..f0f79e8 100644 --- a/swirl_dynamics/projects/ergodic/stable_ar.py +++ b/swirl_dynamics/projects/ergodic/stable_ar.py @@ -71,7 +71,7 @@ def __post_init__(self): self.pred_integrator = functools.partial( pred_integrator, ode.nn_module_to_dynamics(self.conf.dynamics_model) ) - # TODO(lzepedanunez): check if this is compatible with distributed training. + # TODO: check if this is compatible with distributed training. self.vmapped_measure_dist = jax.vmap(self.conf.measure_dist, in_axes=(1, 1)) def initialize(self, rng): @@ -97,7 +97,7 @@ def loss_fn( tspan = batch["tspan"].reshape((-1,)) rollout_weight = batch["rollout_weight"].reshape((-1,)) - # TODO(lzepedanunez): implement the logic in the Neural Markov paper. + # TODO: implement the logic in the Neural Markov paper. if self.conf.add_noise: noise = self.conf.noise_level + jax.random.normal(rng, x0.shape) x0 += noise @@ -131,7 +131,7 @@ def loss_fn( # Compare to true trajectory last step. if self.conf.use_sobolev_norm: - # TODO(yairschiff): Rollout weighting not implemented for this case! + # TODO: Rollout weighting not implemented for this case! # The spatial dimension is the length of the shape minus 2, # which accounts for the batch, frame, and channel dimensions. dim = len(pred.shape) - 2 @@ -163,9 +163,9 @@ def loss_fn( ) # Compare to full reference trajectory. - # TODO(lzepedanunez): this is code is repeated. + # TODO: this is code is repeated. if self.conf.use_sobolev_norm: - # TODO(yairschiff): Rollout weighting not implemented for this case! + # TODO: Rollout weighting not implemented for this case! dim = len(pred.shape) - 3 l2 = ergodic_utils.sobolev_norm( pred - true[:, 1:, ...], @@ -221,7 +221,7 @@ def eval_fn( pred_trajs *= self.conf.normalize_stats["std"] pred_trajs += self.conf.normalize_stats["mean"] - # TODO(lzepedanunez): this only computes the local sinkhorn distance. + # TODO: this only computes the local sinkhorn distance. sd = measure_distances.sinkhorn_div( pred_trajs[:, -1, ...], trajs[:, -1, ...] ) @@ -356,7 +356,7 @@ def preprocess_train_batch( num_time_steps += self.conf.num_rollout_steps + 1 else: num_time_steps = self.conf.num_rollout_steps + 1 - # TODO(yairschiff): Should we remove this random sampling? + # TODO: Should we remove this random sampling? if self.conf.use_pushfwd and num_time_steps > 2: num_time_steps = jax.random.randint( rng, (1,), minval=2, maxval=num_time_steps + 1 @@ -478,7 +478,7 @@ def preprocess_train_batch( num_time_steps += self.conf.num_rollout_steps + 1 else: num_time_steps = self.conf.num_rollout_steps + 1 - # TODO(yairschiff): Should we remove this random sampling? + # TODO: Should we remove this random sampling? if self.conf.use_pushfwd and num_time_steps > 2: num_time_steps = jax.random.randint( rng, (1,), minval=2, maxval=num_time_steps + 1 diff --git a/swirl_dynamics/projects/ergodic/utils.py b/swirl_dynamics/projects/ergodic/utils.py index 59b5eb9..79d93fa 100644 --- a/swirl_dynamics/projects/ergodic/utils.py +++ b/swirl_dynamics/projects/ergodic/utils.py @@ -32,7 +32,7 @@ DynamicsFn = Callable[[Array, Array, PyTree], Array] -# TODO(yairschiff): Move this method to swirl_dynamics.data.utils +# TODO: Move this method to swirl_dynamics.data.utils def generate_data_from_known_dynamcics( integrator: ode.ScanOdeSolver, dynamics: DynamicsFn, @@ -47,7 +47,7 @@ def generate_data_from_known_dynamcics( return integrator(dynamics, x0, tspan, {})[warmup:] -# TODO(yairschiff): Move this method to swirl_dynamics.data.utils +# TODO: Move this method to swirl_dynamics.data.utils def create_loader_from_hdf5( num_time_steps: int, time_stride: int, @@ -180,7 +180,7 @@ def create_loader_from_tfds( """Load pre-computed trajectories dumped to hdf5 file. This loader has fewer options that the one from hdf5, in particular, it has - no normalization. TODO(lzepedanunez): Add normalization. + no normalization. TODO: Add normalization. Arguments: num_time_steps: Number of time steps to include in each trajectory. @@ -241,7 +241,7 @@ def create_loader_from_tfds( return loader, {"mean": None, "std": None} -# TODO(lzepedanunez): find a better place for this function and refactor with +# TODO: find a better place for this function and refactor with # vmap. def sobolev_norm( u: Array, s: int = 1, dim: int = 2, length: float = 1.0 diff --git a/swirl_dynamics/projects/probabilistic_diffusion/models.py b/swirl_dynamics/projects/probabilistic_diffusion/models.py index 73a95db..67be4dc 100644 --- a/swirl_dynamics/projects/probabilistic_diffusion/models.py +++ b/swirl_dynamics/projects/probabilistic_diffusion/models.py @@ -137,7 +137,7 @@ def loss_fn( sigma=sigma, cond=cond, is_training=True, - rngs={"dropout": rng3}, # TODO(lzepedanunez): refactor this. + rngs={"dropout": rng3}, # TODO: refactor this. ) loss = jnp.mean(vmapped_mult(weights, jnp.square(denoised - batch["x"]))) metric = dict(loss=loss) diff --git a/swirl_dynamics/projects/spatiotemporal_modeling/configs/kolmogorov.py b/swirl_dynamics/projects/spatiotemporal_modeling/configs/kolmogorov.py index 76b5f02..1109494 100644 --- a/swirl_dynamics/projects/spatiotemporal_modeling/configs/kolmogorov.py +++ b/swirl_dynamics/projects/spatiotemporal_modeling/configs/kolmogorov.py @@ -28,7 +28,7 @@ def get_config(): config = ml_collections.ConfigDict() # Model. - # TODO(lzepedanunez): undo all the nested dictionaries. + # TODO: undo all the nested dictionaries. config.model_name = 'ViViT Denoiser' config.model = ml_collections.ConfigDict() config.model.hidden_size = 384 # 192 # 768 @@ -45,7 +45,7 @@ def get_config(): config.save_interval_steps = 1000 config.max_checkpoints_to_keep = 10 - # TODO(lzepedanunez): create custom data structures. + # TODO: create custom data structures. config.model.temporal_encoding_config = ml_collections.ConfigDict() config.model.temporal_encoding_config.method = '3d_conv' # pylint: disable=line-too-long @@ -53,7 +53,7 @@ def get_config(): # pylint: enable=line-too-long config.model.positional_embedding = 'sinusoidal_3d' # 'sinusoidal_3d' - # TODO(lzepedanunez): patches doesn't need to be a dictionary. + # TODO: patches doesn't need to be a dictionary. config.model.patches = ml_collections.ConfigDict() config.model.patches.size = (4, 4, 4) # (time, height, width) diff --git a/swirl_dynamics/projects/spatiotemporal_modeling/configs/kolmogorov_med_res.py b/swirl_dynamics/projects/spatiotemporal_modeling/configs/kolmogorov_med_res.py index cf445ee..c8dfb59 100644 --- a/swirl_dynamics/projects/spatiotemporal_modeling/configs/kolmogorov_med_res.py +++ b/swirl_dynamics/projects/spatiotemporal_modeling/configs/kolmogorov_med_res.py @@ -28,7 +28,7 @@ def get_config(): config = ml_collections.ConfigDict() # Model. - # TODO(lzepedanunez): Undo all the nested dictionaries. + # TODO: Undo all the nested dictionaries. config.model_name = 'ViViT Denoiser' config.model = ml_collections.ConfigDict() config.model.hidden_size = 576 @@ -45,7 +45,7 @@ def get_config(): config.save_interval_steps = 1000 config.max_checkpoints_to_keep = 10 - # TODO(lzepedanunez): create custom data structures. + # TODO: create custom data structures. config.model.temporal_encoding_config = ml_collections.ConfigDict() config.model.temporal_encoding_config.method = '3d_conv' # pylint: disable=line-too-long @@ -53,7 +53,7 @@ def get_config(): # pylint: enable=line-too-long config.model.positional_embedding = 'sinusoidal_3d' # 'sinusoidal_3d' - # TODO(lzepedanunez): patches doesn't need to be a dictionary. + # TODO: patches doesn't need to be a dictionary. config.model.patches = ml_collections.ConfigDict() config.model.patches.size = (4, 4, 4) # (time, height, width) diff --git a/swirl_dynamics/projects/spatiotemporal_modeling/configs/kolmogorov_transformer.py b/swirl_dynamics/projects/spatiotemporal_modeling/configs/kolmogorov_transformer.py index ba54825..2fe15c9 100644 --- a/swirl_dynamics/projects/spatiotemporal_modeling/configs/kolmogorov_transformer.py +++ b/swirl_dynamics/projects/spatiotemporal_modeling/configs/kolmogorov_transformer.py @@ -28,7 +28,7 @@ def get_config(): config = ml_collections.ConfigDict() # Model. - # TODO(lzepedanunez) undo all the nested dictionaries. + # TODO undo all the nested dictionaries. config.model_name = 'ViViT Denoiser' config.model = ml_collections.ConfigDict() config.model.hidden_size = 192 @@ -45,7 +45,7 @@ def get_config(): config.save_interval_steps = 1000 config.max_checkpoints_to_keep = 10 - # TODO(lzepedanunez): create custom data structures. + # TODO: create custom data structures. config.model.temporal_encoding_config = ml_collections.ConfigDict() config.model.temporal_encoding_config.method = '3d_conv' # pylint: disable=line-too-long @@ -54,7 +54,7 @@ def get_config(): # pylint: enable=line-too-long config.model.positional_embedding = 'none' # 'sinusoidal_3d' - # TODO(lzepedanunez): patches doesn't need to be a dictionary. + # TODO: patches doesn't need to be a dictionary. config.model.patches = ml_collections.ConfigDict() config.model.patches.size = (4, 4, 4) # (time, height, width) diff --git a/swirl_dynamics/projects/spatiotemporal_modeling/data_utils.py b/swirl_dynamics/projects/spatiotemporal_modeling/data_utils.py index ffd5932..c1ce7af 100644 --- a/swirl_dynamics/projects/spatiotemporal_modeling/data_utils.py +++ b/swirl_dynamics/projects/spatiotemporal_modeling/data_utils.py @@ -96,7 +96,7 @@ def create_loader_from_hdf5( data_for_stats = hdf5_utils.read_single_array(dataset_path, "train/u") else: data_for_stats = snapshots - # TODO(lzepedanunez): For the sake of memory perform this in CPU. + # TODO: For the sake of memory perform this in CPU. if use_time_normalization: num_trajs, num_frames, nx, ny, d = data_for_stats.shape num_segments = num_frames // num_time_steps diff --git a/swirl_dynamics/projects/spatiotemporal_modeling/main.py b/swirl_dynamics/projects/spatiotemporal_modeling/main.py index 9b91871..b525bf4 100644 --- a/swirl_dynamics/projects/spatiotemporal_modeling/main.py +++ b/swirl_dynamics/projects/spatiotemporal_modeling/main.py @@ -184,7 +184,7 @@ def main(argv): base_dir=workdir, options=ckpt_options, ), - # TODO(lzepedanunez) add a plot callback. + # TODO add a plot callback. ), ) diff --git a/swirl_dynamics/templates/train.py b/swirl_dynamics/templates/train.py index 1d09e5e..23d479b 100644 --- a/swirl_dynamics/templates/train.py +++ b/swirl_dynamics/templates/train.py @@ -27,7 +27,7 @@ filesys = epath.backend.tf_backend -# TODO(wanzy): package parameters into logical groupings (see cl/497196196) +# TODO: package parameters into logical groupings (see cl/497196196) def run( *, train_dataloader: Iterable[Any], diff --git a/swirl_dynamics/templates/train_states.py b/swirl_dynamics/templates/train_states.py index bbd81d7..6412db2 100644 --- a/swirl_dynamics/templates/train_states.py +++ b/swirl_dynamics/templates/train_states.py @@ -29,7 +29,7 @@ import optax from orbax import checkpoint -# TODO(wanzy): use typing.Self after python 3.11 (PEP 673) +# TODO: use typing.Self after python 3.11 (PEP 673) TState = TypeVar("TState", bound="TrainState") EMPTY_DICT = flax.core.freeze({})