
Commit

fixing lint except yapf
znado committed May 11, 2022
1 parent 63a3810 commit 56dbd51
Showing 10 changed files with 25 additions and 25 deletions.
1 change: 1 addition & 0 deletions .github/workflows/CI.yml
@@ -21,6 +21,7 @@ jobs:
- name: Run pytest
run: |
pytest -vx tests/version_test.py
+ pytest -vx tests/reference_submission_tests.py
pytest -vx tests/workloads/imagenet/imagenet_jax/workload_test.py
pytest -vx tests/test_num_params.py
pytest -vx tests/test_param_shapes.py
1 change: 1 addition & 0 deletions .github/workflows/linting.yml
@@ -19,6 +19,7 @@ jobs:
run: |
pylint algorithmic_efficiency
pylint baselines
+ pylint reference_submissions
pylint submission_runner.py
pylint tests
2 changes: 2 additions & 0 deletions algorithmic_efficiency/spec.py
@@ -3,8 +3,10 @@
import abc
import enum
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union
+
+ from absl import logging


class LossType(enum.Enum):
SOFTMAX_CROSS_ENTROPY = 0
SIGMOID_CROSS_ENTROPY = 1
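
For context, the newly imported absl logging mirrors the standard-library logging interface with %-style lazy formatting. A minimal usage sketch (the messages below are illustrative, not code from spec.py):

from absl import logging

# Leveled calls with %-style arguments, formatted only if the record is emitted.
logging.info('Loss type set to %s.', 'SOFTMAX_CROSS_ENTROPY')
logging.warning('Falling back to default hyperparameters.')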
@@ -1,8 +1,8 @@
"""ImageNet workload implemented in PyTorch."""

import contextlib
- import os
import math
+ import os
from typing import Tuple

import torch
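
The os/math swap above restores the alphabetical order the linter expects inside each import group. A sketch of the resulting convention (standard library first, alphabetized, then a blank line before third-party packages):

import contextlib
import math
import os
from typing import Tuple

import torch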
@@ -16,7 +16,6 @@
from algorithmic_efficiency.workloads.librispeech.librispeech_pytorch import \
models


device = torch.device(
"cuda:0" if torch.cuda.is_available() else "cpu")

6 changes: 4 additions & 2 deletions algorithmic_efficiency/workloads/wmt/wmt_jax/workload.py
@@ -11,7 +11,6 @@
import jax
import jax.numpy as jnp
import numpy as np
- import tensorflow as tf

from algorithmic_efficiency import spec
from algorithmic_efficiency.workloads.wmt import bleu
@@ -89,7 +88,10 @@ def compute_weighted_cross_entropy(self,
((self._vocab_size - 1) * low_confidence *
jnp.log(low_confidence + 1e-20)))
soft_targets = common_utils.onehot(
- targets, self._vocab_size, on_value=confidence, off_value=low_confidence)
+ targets,
+ self._vocab_size,
+ on_value=confidence,
+ off_value=low_confidence)

loss = -jnp.sum(soft_targets * nn.log_softmax(logits), axis=-1)
loss = loss - normalizing_constant
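
The reformatted call above builds label-smoothed soft targets: the true class gets probability confidence, every other class gets low_confidence, and subtracting normalizing_constant (the entropy of the soft distribution) makes the minimum attainable loss zero. A self-contained sketch of the same computation, using plain jax.nn.one_hot in place of flax's common_utils.onehot (the function name and smoothing default are illustrative):

import jax.numpy as jnp
from jax import nn


def label_smoothed_xent(logits, targets, vocab_size, smoothing=0.1):
  confidence = 1.0 - smoothing
  low_confidence = smoothing / (vocab_size - 1)
  # Entropy of the soft target distribution; subtracted so that a model
  # matching the soft targets exactly scores a loss of zero.
  normalizing_constant = -(
      confidence * jnp.log(confidence) +
      (vocab_size - 1) * low_confidence * jnp.log(low_confidence + 1e-20))
  one_hot = nn.one_hot(targets, vocab_size)
  soft_targets = one_hot * confidence + (1.0 - one_hot) * low_confidence
  loss = -jnp.sum(soft_targets * nn.log_softmax(logits), axis=-1)
  return loss - normalizing_constant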
3 changes: 2 additions & 1 deletion reference_submissions/imagenet/imagenet_jax/submission.py
@@ -24,7 +24,8 @@ def cosine_decay(lr, step, total_steps):
return mult * lr


- def create_learning_rate_fn(hparams: spec.Hyperparameters, steps_per_epoch: int):
+ def create_learning_rate_fn(
+     hparams: spec.Hyperparameters, steps_per_epoch: int):
"""Create learning rate schedule."""
base_learning_rate = hparams.learning_rate * get_batch_size('imagenet') / 256.
warmup_fn = optax.linear_schedule(
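
The truncated optax.linear_schedule call above is the warmup half of a warmup-then-cosine learning-rate schedule. A hedged sketch of the usual optax composition (argument names here are illustrative; the real function reads warmup_epochs and num_epochs from hparams):

import optax


def warmup_cosine_schedule(base_lr, warmup_epochs, num_epochs, steps_per_epoch):
  # Linear warmup from 0 to base_lr over the first warmup_epochs.
  warmup_fn = optax.linear_schedule(
      init_value=0.0,
      end_value=base_lr,
      transition_steps=warmup_epochs * steps_per_epoch)
  # Cosine decay from base_lr down to 0 over the remaining epochs.
  cosine_fn = optax.cosine_decay_schedule(
      init_value=base_lr,
      decay_steps=max(num_epochs - warmup_epochs, 1) * steps_per_epoch)
  return optax.join_schedules(
      schedules=[warmup_fn, cosine_fn],
      boundaries=[warmup_epochs * steps_per_epoch])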
3 changes: 2 additions & 1 deletion reference_submissions/mnist/mnist_pytorch/submission.py
@@ -64,7 +64,8 @@ def update_params(
rng=rng,
update_batch_norm=True)

- loss = workload.loss_fn(label_batch=batch['targets'], logits_batch=output).mean()
+ loss = workload.loss_fn(
+     label_batch=batch['targets'], logits_batch=output).mean()

loss.backward()
optimizer_state['optimizer'].step()
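
The reformatted loss_fn call is one piece of the standard PyTorch update: forward pass, scalar loss, backward pass, optimizer step. A minimal sketch of that pattern with stand-in names (model, optimizer, and batch here are not the workload's API):

import torch
import torch.nn.functional as F


def train_step(model, optimizer, batch):
  optimizer.zero_grad()                             # clear stale gradients
  logits = model(batch['inputs'])
  loss = F.cross_entropy(logits, batch['targets'])  # mean over the batch
  loss.backward()                                   # backpropagate
  optimizer.step()                                  # apply the update
  return loss.item()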
2 changes: 0 additions & 2 deletions reference_submissions/wmt/wmt_jax/submission.py
@@ -5,11 +5,9 @@

from flax import jax_utils
from flax import linen as nn
- from flax import optim
from flax.training import common_utils
import jax
import jax.numpy as jnp
- import numpy as np
import optax

from algorithmic_efficiency import spec
29 changes: 12 additions & 17 deletions tests/reference_submission_tests.py
@@ -220,18 +220,18 @@ def _test_submission(workload_name,
hyperparameters,
global_step,
data_select_rng)
- # _, model_params, model_state = update_params(
- #     workload=workload,
- #     current_param_container=model_params,
- #     current_params_types=workload.model_params_types,
- #     model_state=model_state,
- #     hyperparameters=hyperparameters,
- #     batch=batch,
- #     loss_type=workload.loss_type,
- #     optimizer_state=optimizer_state,
- #     eval_results=[],
- #     global_step=global_step,
- #     rng=update_rng)
+ _, model_params, model_state = update_params(
+     workload=workload,
+     current_param_container=model_params,
+     current_params_types=workload.model_params_types,
+     model_state=model_state,
+     hyperparameters=hyperparameters,
+     batch=batch,
+     loss_type=workload.loss_type,
+     optimizer_state=optimizer_state,
+     eval_results=[],
+     global_step=global_step,
+     rng=update_rng)
eval_result = workload.eval_model(global_batch_size,
model_params,
model_state,
@@ -261,11 +261,6 @@ def test_submission(self):
for framework in ['jax', 'pytorch']:
submission_dir = f'{workload_dir}/{workload_name}_{framework}'
if os.path.exists(submission_dir):
- # # DO NOT SUBMIT
- # if 'mnist' in submission_dir or 'imagenet' in submission_dir or 'librispeech' in submission_dir or 'ogbg' in submission_dir:
- #   continue
- # if not ('librispeech' in submission_dir):
- #   continue
submission_path = (f'reference_submissions/{workload_name}/'
f'{workload_name}_{framework}/submission.py')
logging.info(f'========= Testing {workload_name} in {framework}.')
