Merge branch 'main' into ref-algos-rename
znado committed Oct 12, 2022
2 parents 7b664c6 + 018b2f0 commit 1177e36
Showing 38 changed files with 128 additions and 101 deletions.
18 changes: 4 additions & 14 deletions algorithmic_efficiency/logger_utils.py
@@ -12,16 +12,13 @@
import psutil

from algorithmic_efficiency import spec
from algorithmic_efficiency.pytorch_utils import pytorch_setup

try:
import wandb # pylint: disable=g-import-not-at-top
except ModuleNotFoundError:
logging.exception('Unable to import wandb.')
wandb = None

USE_PYTORCH_DDP, RANK, DEVICE, N_GPUS = pytorch_setup()


def _get_utilization() -> dict:
util_data = {}
@@ -141,14 +138,6 @@ def _get_cpu_model_name() -> str:
output)[0].split('Model name:')[1].strip()


def _get_os_package_list() -> str:
return subprocess.check_output(['dpkg', '-l']).decode('ascii').strip()


def _get_pip_package_list() -> str:
return subprocess.check_output(['pip', 'freeze']).decode('ascii').strip()


def _is_primitive_type(item: Any) -> bool:
primitive = (float, int, str, bool)
return isinstance(item, primitive)
@@ -199,10 +188,11 @@ def __init__(self,
configs: Optional[flags.FLAGS] = None) -> None:
self._measurements = {}
self._csv_path = csv_path
self.use_wandb = configs.use_wandb

if events_dir:
self._tb_metric_writer = metric_writers.create_default_writer(events_dir)
if wandb is not None:
if wandb is not None and self.use_wandb:
wandb.init(
dir=events_dir, tags=[flags.FLAGS.workload, flags.FLAGS.framework])
wandb.config.update(configs)
@@ -228,11 +218,11 @@ def append_scalar_metrics(self, metrics: dict, global_step: int) -> None:
step=int(metrics['global_step']), scalars=metrics)
self._tb_metric_writer.flush()

if wandb is not None:
if wandb is not None and self.use_wandb:
wandb.log(metrics)

def finish(self) -> None:
if wandb is not None:
if wandb is not None and self.use_wandb:
wandb.finish()


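The logging changes above gate every wandb call behind two checks: the optional import must have succeeded and the user must have opted in via the new use_wandb config flag. A minimal, self-contained sketch of the same pattern (the MetricLogger name, configs object, and wandb calls are taken from the diff; everything else is illustrative, not the repository's exact implementation):

import logging

try:
  import wandb  # pylint: disable=g-import-not-at-top
except ModuleNotFoundError:
  logging.exception('Unable to import wandb.')
  wandb = None


class MetricLogger:

  def __init__(self, configs) -> None:
    # wandb is touched only if the package imported successfully AND the
    # user explicitly enabled it.
    self.use_wandb = configs.use_wandb
    if wandb is not None and self.use_wandb:
      wandb.init()
      wandb.config.update(configs)

  def append_scalar_metrics(self, metrics: dict) -> None:
    if wandb is not None and self.use_wandb:
      wandb.log(metrics)

  def finish(self) -> None:
    if wandb is not None and self.use_wandb:
      wandb.finish()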
2 changes: 1 addition & 1 deletion algorithmic_efficiency/param_utils.py
@@ -44,7 +44,7 @@ def jax_param_types(param_shapes, parent_name=''):
# Note that this is exact equality, not contained in, because
# flax.linen.Embed names the embedding parameter "embedding"
# https://github.com/google/flax/blob/main/flax/linen/linear.py#L604.
elif name == 'embedding':
elif 'embedding' in name:
param_types_dict[name] = spec.ParameterType.EMBEDDING
else:
param_types_dict[name] = spec.ParameterType.WEIGHT
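The relaxed check above marks any parameter whose name contains 'embedding' as an embedding parameter, whereas the previous exact-equality test only matched the literal name that flax.linen.Embed assigns. A self-contained sketch of the matching behaviour; the ParameterType enum is a stand-in for spec.ParameterType, and 'shared_embedding' is a hypothetical name used only to illustrate the difference:

import enum


class ParameterType(enum.Enum):
  WEIGHT = 0
  EMBEDDING = 1


def classify(name: str) -> ParameterType:
  # Substring match instead of exact equality, so prefixed or nested names
  # like 'shared_embedding' are still classified as embeddings.
  if 'embedding' in name:
    return ParameterType.EMBEDDING
  return ParameterType.WEIGHT


assert classify('embedding') is ParameterType.EMBEDDING
assert classify('shared_embedding') is ParameterType.EMBEDDING
assert classify('Dense_0') is ParameterType.WEIGHT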
7 changes: 4 additions & 3 deletions algorithmic_efficiency/spec.py
@@ -101,9 +101,10 @@ def __init__(self, shape_tuple):

class Workload(metaclass=abc.ABCMeta):

_param_shapes: Optional[ParameterShapeTree] = None
_param_types: Optional[ParameterTypeTree] = None
_eval_iters: dict = {}
def __init__(self) -> None:
self._param_shapes: Optional[ParameterShapeTree] = None
self._param_types: Optional[ParameterTypeTree] = None
self._eval_iters: Dict[str, Iterator] = {}

@abc.abstractmethod
def has_reached_goal(self, eval_result: float) -> bool:
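Moving _param_shapes, _param_types, and _eval_iters from class attributes into __init__ avoids the usual pitfall of a mutable class-level default: a dict declared on the class is a single object shared by every Workload instance. A small illustration of the difference (class names are illustrative only):

class SharedState:
  cache: dict = {}  # a single dict object, shared by all instances


class PerInstanceState:

  def __init__(self) -> None:
    self.cache: dict = {}  # each instance gets its own dict


a, b = SharedState(), SharedState()
a.cache['step'] = 1
assert b.cache == {'step': 1}  # state leaks across instances

c, d = PerInstanceState(), PerInstanceState()
c.cache['step'] = 1
assert d.cache == {}  # isolated, as intended for _eval_iters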
3 changes: 3 additions & 0 deletions algorithmic_efficiency/workloads/cifar/cifar_jax/workload.py
@@ -90,6 +90,9 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
params = jax_utils.replicate(params)
return params, model_state

def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key == 'Dense_0'

@functools.partial(
jax.pmap,
axis_name='batch',
@@ -100,10 +100,8 @@ def _build_dataset(self,

return dataloader

# Return whether or not a key in spec.ParameterContainer is the output layer
# parameters.
def is_output_params(self, param_key: spec.ParameterKey) -> bool:
pass
return param_key in ['fc.weight', 'fc.bias']

def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
torch.random.manual_seed(rng[0])
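The new is_output_params implementations let each workload answer, per parameter key, whether that key belongs to the output layer: the Jax workload matches the flax module name ('Dense_0'), while the PyTorch workload matches the named-parameter keys ('fc.weight', 'fc.bias'). A hedged sketch of how such a hook might be consumed; the split_params helper and the weight-decay use case are illustrative, not part of the API shown in this diff:

def split_params(workload, param_container):
  """Illustrative helper: separate output-layer keys from all other keys."""
  output_keys, other_keys = [], []
  for key in param_container:
    if workload.is_output_params(key):
      output_keys.append(key)
    else:
      other_keys.append(key)
  return output_keys, other_keys

# e.g. a submission could exclude output_keys from weight decay, or
# re-initialize them before fine-tuning.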
5 changes: 0 additions & 5 deletions algorithmic_efficiency/workloads/cifar/workload.py
@@ -67,8 +67,3 @@ def step_hint(self) -> int:
# workload, but for completeness we provide the number of steps for 100
# epochs at batch size 1024.
return 4883

# Return whether or not a key in spec.ParameterTree is the output layer
# parameters.
def is_output_params(self, param_key: spec.ParameterKey) -> bool:
raise NotImplementedError
@@ -53,6 +53,9 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
self._param_types = param_utils.jax_param_types(self._param_shapes)
return jax_utils.replicate(initial_params), None

def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key == 'Dense_4'

def model_fn(
self,
params: spec.ParameterContainer,
@@ -58,6 +58,9 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
model = torch.nn.DataParallel(model)
return model, None

def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key in ['top_mlp.4.weight', 'top_mlp.4.bias']

def model_fn(
self,
params: spec.ParameterContainer,
18 changes: 6 additions & 12 deletions algorithmic_efficiency/workloads/criteo1tb/workload.py
@@ -1,5 +1,5 @@
import math
from typing import Dict, Optional
from typing import Dict, Optional, Tuple

from absl import flags
import jax
@@ -13,12 +13,11 @@
class BaseCriteo1TbDlrmSmallWorkload(spec.Workload):
"""Criteo1tb workload."""

def __init__(self):
self.vocab_sizes = tuple([1024 * 128] * 26)
self.num_dense_features = 13
self.mlp_bottom_dims = (128, 128)
self.mlp_top_dims = (256, 128, 1)
self.embed_dim = 64
vocab_sizes: Tuple[int] = tuple([1024 * 128] * 26)
num_dense_features: int = 13
mlp_bottom_dims: Tuple[int, int] = (128, 128)
mlp_top_dims: Tuple[int, int, int] = (256, 128, 1)
embed_dim: int = 64

def has_reached_goal(self, eval_result: float) -> bool:
return eval_result['validation/loss'] < self.target_value
@@ -83,11 +82,6 @@ def _build_input_queue(self,
for batch in iter(ds):
yield batch

# Return whether or not a key in spec.ParameterContainer is the output layer
# parameters.
def is_output_params(self, param_key: spec.ParameterKey) -> bool:
pass

@property
def step_hint(self) -> int:
"""Max num steps the target setting algo was given to reach the target."""
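In contrast to the mutable _eval_iters dict moved into __init__ in spec.py, these DLRM hyperparameters are immutable ints and tuples, so hoisting them from __init__ to class attributes is safe: every instance reads the same constant values and nothing can be mutated in place. A short illustration, trimmed to two of the attributes:

class BaseCriteo1TbDlrmSmallWorkload:
  vocab_sizes = tuple([1024 * 128] * 26)
  num_dense_features = 13


w1, w2 = BaseCriteo1TbDlrmSmallWorkload(), BaseCriteo1TbDlrmSmallWorkload()
assert w1.vocab_sizes is w2.vocab_sizes  # shared object, but tuples are immutable
assert w1.num_dense_features == 13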
@@ -28,6 +28,9 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
params = jax_utils.replicate(params)
return params, None

def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key == 'Conv_0'

def model_fn(
self,
params: spec.ParameterContainer,
@@ -97,7 +97,7 @@ def __init__(self,
self.out_chans = out_chans
self.dropout = dropout

self.layers = nn.Sequential(
self.conv_layers = nn.Sequential(
nn.Conv2d(in_chans, out_chans, kernel_size=3, padding=1, bias=False),
nn.InstanceNorm2d(out_chans),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
Expand All @@ -109,7 +109,7 @@ def __init__(self,
)

def forward(self, x: Tensor) -> Tensor:
return self.layers(x)
return self.conv_layers(x)


class TransposeConvBlock(nn.Module):
@@ -112,6 +112,9 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
model = torch.nn.DataParallel(model)
return model, None

def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key in ['up_conv.3.1.weight', 'up_conv.3.1.bias']

def model_fn(
self,
params: spec.ParameterContainer,
5 changes: 0 additions & 5 deletions algorithmic_efficiency/workloads/fastmri/workload.py
@@ -64,11 +64,6 @@ def step_hint(self) -> int:
"""Max num steps the target setting algo was given to reach the target."""
return 27142

# Return whether or not a key in spec.ParameterTree is the output layer
# parameters.
def is_output_params(self, param_key: spec.ParameterKey) -> bool:
raise NotImplementedError

def _build_input_queue(self,
data_rng: spec.RandomState,
split: str,
@@ -93,6 +93,9 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
params = jax_utils.replicate(params)
return params, model_state

def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key == 'Dense_0'

@functools.partial(
jax.pmap,
axis_name='batch',
@@ -139,6 +139,9 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
model = torch.nn.DataParallel(model)
return model, None

def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key in ['fc.weight', 'fc.bias']

def _update_batch_norm(self, model, update_batch_norm):
bn_layers = (nn.BatchNorm1d,
nn.BatchNorm2d,
@@ -7,13 +7,17 @@
import jax
import numpy as np
import tensorflow_datasets as tfds
import torch

from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax import \
input_pipeline


def _shard(x):
return x.reshape((jax.local_device_count(), -1, *x.shape[1:]))
# If we install the CPU version of a framework it may not return the correct
# number of GPUs.
num_devices = max(torch.cuda.device_count(), jax.local_device_count())
return x.reshape((num_devices, -1, *x.shape[1:]))


def shard_and_maybe_pad_batch(desired_batch_size, shard_batch, tf_batch):
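_shard reshapes the leading batch dimension into (num_devices, per_device_batch_size, ...), so each accelerator receives one contiguous slice; the device count is taken as the max of what torch and jax report because a CPU-only install of either framework can under-report the available GPUs. A small numpy illustration of the reshape, with the device count hard-coded purely for the example:

import numpy as np

num_devices = 4  # in the diff: max(torch.cuda.device_count(), jax.local_device_count())
batch = np.zeros((32, 224, 224, 3))  # global batch of 32 images
sharded = batch.reshape((num_devices, -1, *batch.shape[1:]))
assert sharded.shape == (4, 8, 224, 224, 3)  # 8 examples per device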
8 changes: 1 addition & 7 deletions algorithmic_efficiency/workloads/imagenet_resnet/workload.py
@@ -8,8 +8,7 @@

class BaseImagenetResNetWorkload(spec.Workload):

def __init__(self):
self._num_classes = 1000
_num_classes: int = 1000

def has_reached_goal(self, eval_result: float) -> bool:
return eval_result['validation/accuracy'] > self.target_value
@@ -72,11 +71,6 @@ def max_allowed_runtime_sec(self):
def eval_period_time_sec(self):
return 510 # 8.5 minutes.

# Return whether or not a key in spec.ParameterTree is the output layer
# parameters.
def is_output_params(self, param_key: spec.ParameterKey) -> bool:
raise NotImplementedError

def _build_input_queue(self,
data_rng: spec.RandomState,
split: str,
@@ -38,6 +38,9 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
params = jax_utils.replicate(params)
return params, model_state

def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key == 'pre_logits'

def model_fn(
self,
params: spec.ParameterContainer,
@@ -38,6 +38,9 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
model = torch.nn.DataParallel(model)
return model, None

def is_output_params(self, param_key: spec.ParameterKey) -> bool:
return param_key in ['pre_logits.weight', 'pre_logits.bias']

def model_fn(
self,
params: spec.ParameterContainer,
@@ -43,6 +43,9 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
params = jax_utils.replicate(params)
return params, model_state

def is_output_params(self, param_key: spec.ParameterKey) -> bool:
pass

def init_tokenizer(self, tokenizer_vocab_path):
logging.info('Initializing metrics bundle and tokenizer.')
self.metrics_bundle = metrics.get_metrics_bundle(tokenizer_vocab_path)
@@ -68,6 +68,9 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
model = torch.nn.DataParallel(model)
return model, None

def is_output_params(self, param_key: spec.ParameterKey) -> bool:
pass

def init_tokenizer(self, tokenizer_vocab_path):
logging.info('Initializing tokenizer.')
self.tokenizer = metrics.load_tokenizer(tokenizer_vocab_path)
@@ -14,8 +14,7 @@

class BaseLibrispeechWorkload(spec.Workload):

def __init__(self) -> None:
self._num_outputs = 1024
_num_outputs: int = 1024

def has_reached_goal(self, eval_result: float) -> bool:
return eval_result['validation/wer'] < self.target_value
@@ -163,11 +162,6 @@ def _build_dataset(self,

yield batch

# Return whether or not a key in spec.ParameterContainer is the output layer
# parameters.
def is_output_params(self, param_key: spec.ParameterKey) -> bool:
pass

@property
def step_hint(self) -> int:
"""Max num steps the target setting algo was given to reach the target."""
@@ -43,6 +43,9 @@ def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
params = jax_utils.replicate(params)
return params, model_state

def is_output_params(self, param_key: spec.ParameterKey) -> bool:
pass

def model_fn(
self,
params: spec.ParameterContainer,
14 changes: 4 additions & 10 deletions algorithmic_efficiency/workloads/mnist/mnist_jax/workload.py
@@ -48,10 +48,6 @@ def _param_types(param_tree):

class MnistWorkload(BaseMnistWorkload):

def __init__(self):
super().__init__()
self._model = _Model()

def _normalize(self, image):
return (tf.cast(image, tf.float32) - self.train_mean) / self.train_stddev

@@ -82,10 +78,8 @@ def _build_dataset(self,
ds = map(data_utils.shard_numpy_ds, ds)
return iter(ds)

# Return whether or not a key in spec.ParameterContainer is the output layer
# parameters.
def is_output_params(self, param_key: spec.ParameterKey) -> bool:
pass
return param_key == 'Dense_1'

def _build_input_queue(self,
data_rng,
@@ -100,8 +94,8 @@ def _build_input_queue(self,

def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
init_val = jnp.ones((1, 28, 28, 1), jnp.float32)
initial_params = self._model.init({'params': rng}, init_val,
train=True)['params']
initial_params = _Model().init({'params': rng}, init_val,
train=True)['params']
self._param_shapes = param_utils.jax_param_shapes(initial_params)
self._param_types = param_utils.jax_param_types(self._param_shapes)
return jax_utils.replicate(initial_params), None
@@ -123,7 +117,7 @@ def model_fn(
del aux_dropout_rate
del update_batch_norm
train = mode == spec.ForwardPassMode.TRAIN
logits_batch = self._model.apply(
logits_batch = _Model().apply(
{'params': params},
augmented_and_preprocessed_input_batch['inputs'],
train=train)
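The MNIST Jax workload no longer stores a _Model instance on self; init_model_fn and model_fn simply construct _Model() on the fly. This works because flax.linen modules are stateless hyperparameter containers: the parameters live in the pytree returned by .init and are passed back explicitly to .apply, so a freshly built module behaves identically. A minimal flax sketch of the pattern (the TinyModel definition is illustrative, not the workload's _Model):

import flax.linen as nn
import jax
import jax.numpy as jnp


class TinyModel(nn.Module):

  @nn.compact
  def __call__(self, x):
    return nn.Dense(features=10)(x)


rng = jax.random.PRNGKey(0)
x = jnp.ones((1, 28 * 28))
params = TinyModel().init(rng, x)['params']        # state lives in this pytree
logits = TinyModel().apply({'params': params}, x)  # a fresh instance works fine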