Skip to content

Commit

Permalink
fixing imports
Browse files Browse the repository at this point in the history
  • Loading branch information
znado committed Mar 2, 2022
1 parent 6020fae commit d17e0ef
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@
from jax import lax
import jax.numpy as jnp
import numpy as np
from workloads.imagenet.workload import ImagenetWorkload

from algorithmic_efficiency import spec
from algorithmic_efficiency.workloads.imagenet.imagenet_jax import \
input_pipeline
from algorithmic_efficiency.workloads.imagenet.imagenet_jax import models
from algorithmic_efficiency.workloads.imagenet.workload import ImagenetWorkload


_InitState = Tuple[spec.ParameterContainer, spec.ModelAuxiliaryState] # pylint: disable=invalid-name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
import jax.numpy as jnp
import tensorflow as tf
import tensorflow_datasets as tfds
from workloads.mnist.workload import Mnist

from algorithmic_efficiency import spec
from algorithmic_efficiency.workloads.mnist.workload import Mnist


class _Model(nn.Module):
Expand Down
4 changes: 2 additions & 2 deletions submission_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,12 +48,12 @@
'imagenet_jax': {
'workload_path':
BASE_WORKLOADS_DIR + 'imagenet/imagenet_jax/workload.py',
'workload_class_name': 'ImagenetWorkload',
'workload_class_name': 'ImagenetJaxWorkload',
},
'imagenet_pytorch': {
'workload_path':
BASE_WORKLOADS_DIR + 'imagenet/imagenet_pytorch/workload.py',
'workload_class_name': 'ImagenetWorkload',
'workload_class_name': 'ImagenetPytorchWorkload',
},
'wmt_jax': {
'workload_path': BASE_WORKLOADS_DIR + 'wmt/wmt_jax/workload.py',
Expand Down

0 comments on commit d17e0ef

Please sign in to comment.