Commit

Merge branch 'dev' into wmt-speed
Conflicts:
	algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py
runame committed Aug 16, 2023
2 parents 7308c91 + 95190b5 commit 8d99d81
Showing 11 changed files with 95 additions and 74 deletions.
16 changes: 8 additions & 8 deletions README.md
@@ -123,7 +123,7 @@ To use the Docker container as an interactive virtual environment, you can run a
--gpus all \
--ipc=host \
<docker_image_name>
-keep_container_alive true
--keep_container_alive true
```
2. Open a bash terminal
```bash
@@ -148,8 +148,8 @@ python3 submission_runner.py \
--workload=mnist \
--experiment_dir=$HOME/experiments \
--experiment_name=my_first_experiment \
--submission_path=reference_algorithms/development_algorithms/mnist/mnist_jax/submission.py \
--tuning_search_space=reference_algorithms/development_algorithms/mnist/tuning_search_space.json
--submission_path=baselines/adamw/jax/submission.py \
--tuning_search_space=baselines/adamw/tuning_search_space.json
```

**Pytorch**
@@ -160,8 +160,8 @@ python3 submission_runner.py \
--workload=mnist \
--experiment_dir=$HOME/experiments \
--experiment_name=my_first_experiment \
--submission_path=reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py \
--tuning_search_space=reference_algorithms/development_algorithms/mnist/tuning_search_space.json
--submission_path=baselines/adamw/pytorch/submission.py \
--tuning_search_space=baselines/adamw/tuning_search_space.json
```
<details>
<summary>
@@ -186,10 +186,10 @@ torchrun --redirects 1:0,2:0,3:0,4:0,5:0,6:0,7:0 --standalone --nnodes=1 --nproc
submission_runner.py \
--framework=pytorch \
--workload=mnist \
--experiment_dir=/home/znado \
--experiment_dir=$HOME/experiments \
--experiment_name=baseline \
--submission_path=reference_algorithms/development_algorithms/mnist/mnist_pytorch/submission.py \
--tuning_search_space=reference_algorithms/development_algorithms/mnist/tuning_search_space.json \
--submission_path=baselines/adamw/pytorch/submission.py \
--tuning_search_space=baselines/adamw/tuning_search_space.json
```
</details>

@@ -1,4 +1,5 @@
"""Criteo1TB workload implemented in Jax."""

import functools
from typing import Dict, Optional, Tuple

@@ -1,4 +1,5 @@
"""Criteo1TB workload implemented in PyTorch."""

import contextlib
from typing import Dict, Optional, Tuple

@@ -79,6 +80,8 @@ def init_model_fn(
"""Only dropout is used."""
del aux_dropout_rate
torch.random.manual_seed(rng[0])
# Disable cudnn benchmark to avoid OOM errors.
torch.backends.cudnn.benchmark = False
model = DlrmSmall(
vocab_size=self.vocab_size,
num_dense_features=self.num_dense_features,
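The `torch.backends.cudnn.benchmark = False` line added above is annotated as an OOM workaround. A minimal sketch of the flag in isolation (toy convolution and hypothetical shapes, not the DLRM model; it needs a CUDA device to actually exercise cuDNN):

```python
import torch

# With benchmark=True, cuDNN autotunes convolution algorithms per input shape,
# which can allocate extra workspace memory; benchmark=False skips autotuning
# and uses a default algorithm, keeping memory usage more predictable.
torch.backends.cudnn.benchmark = False

if torch.cuda.is_available():
  conv = torch.nn.Conv2d(3, 16, kernel_size=3).cuda()
  x = torch.randn(2, 3, 32, 32, device='cuda')
  with torch.no_grad():
    _ = conv(x)
```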
@@ -5,6 +5,7 @@
validation). See here for the NVIDIA example:
https://github.com/NVIDIA/DeepLearningExamples/blob/4e764dcd78732ebfe105fc05ea3dc359a54f6d5e/PyTorch/Recommendation/DLRM/preproc/run_spark_cpu.sh#L119.
"""

import functools
import os
from typing import Optional
42 changes: 28 additions & 14 deletions algorithmic_efficiency/workloads/criteo1tb/workload.py
@@ -1,14 +1,17 @@
"""Criteo1TB DLRM workload base class."""

import math
import os
from typing import Dict, Optional, Tuple
from typing import Dict, Iterator, Optional, Tuple

import jax
from absl import flags
import torch.distributed as dist

from algorithmic_efficiency import spec
from algorithmic_efficiency.workloads.criteo1tb import input_pipeline

FLAGS = flags.FLAGS

USE_PYTORCH_DDP = 'LOCAL_RANK' in os.environ


@@ -26,14 +29,15 @@ def target_metric_name(self) -> str:
"""The name of the target metric (useful for scoring/processing code)."""
return 'loss'

def has_reached_validation_target(self, eval_result: float) -> bool:
def has_reached_validation_target(self, eval_result: Dict[str,
float]) -> bool:
return eval_result['validation/loss'] < self.validation_target_value

@property
def validation_target_value(self) -> float:
return 0.123649

def has_reached_test_target(self, eval_result: float) -> bool:
def has_reached_test_target(self, eval_result: Dict[str, float]) -> bool:
return eval_result['test/loss'] < self.test_target_value

@property
@@ -75,19 +79,22 @@ def train_stddev(self):

@property
def max_allowed_runtime_sec(self) -> int:
return 7703 # ~2 hours
return 7703 # ~2 hours.

@property
def eval_period_time_sec(self) -> int:
return 2 * 60

def _build_input_queue(self,
data_rng: jax.random.PRNGKey,
split: str,
data_dir: str,
global_batch_size: int,
num_batches: Optional[int] = None,
repeat_final_dataset: bool = False):
return 2 * 600 # 20 mins.

def _build_input_queue(
self,
data_rng: spec.RandomState,
split: str,
data_dir: str,
global_batch_size: int,
cache: Optional[bool] = None,
repeat_final_dataset: Optional[bool] = None,
num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]:
del cache
ds = input_pipeline.get_criteo1tb_dataset(
split=split,
shuffle_rng=data_rng,
@@ -134,4 +141,11 @@ def _eval_model_on_split(self,
if USE_PYTORCH_DDP:
dist.all_reduce(loss)
mean_loss = loss.item() / num_examples
if FLAGS.framework == 'pytorch':
# For PyTorch, the saved iterators cause OOM after evaluation.
# Hence, we create new iterators for each evaluation step. While this
# slows down the overall time to perform evaluation, this does not affect
# the final score.
del self._eval_iters
self._eval_iters = {}
return {'loss': mean_loss}
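The comment added above explains that cached evaluation iterators are discarded after each eval to avoid OOM in PyTorch. A minimal sketch of that pattern, with a hypothetical `build_eval_iter` factory standing in for the workload's real input pipeline:

```python
from typing import Callable, Dict, Iterator


class EvalIterManager:
  """Sketch: rebuild eval iterators for every evaluation instead of caching."""

  def __init__(self, build_eval_iter: Callable[[str], Iterator]):
    self._build_eval_iter = build_eval_iter
    self._eval_iters: Dict[str, Iterator] = {}

  def get(self, split: str) -> Iterator:
    if split not in self._eval_iters:
      self._eval_iters[split] = self._build_eval_iter(split)
    return self._eval_iters[split]

  def reset(self) -> None:
    # Dropping the references lets prefetch buffers be garbage collected,
    # mirroring the `del self._eval_iters` in the workload code above.
    del self._eval_iters
    self._eval_iters = {}


# Toy usage:
manager = EvalIterManager(lambda split: iter(range(3)))
batches = list(manager.get('validation'))
manager.reset()
```

This trades a little extra setup time per evaluation for a bounded memory footprint, which is the trade-off the comment describes.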
@@ -281,7 +281,7 @@ def __init__(self, config: DeepspeechConfig):

def forward(self, inputs, input_paddings):
inputs = self.bn(inputs, input_paddings)
lengths = torch.sum(1 - input_paddings, dim=1).detach().cpu()
lengths = torch.sum(1 - input_paddings, dim=1).detach().cpu().numpy()
packed_inputs = torch.nn.utils.rnn.pack_padded_sequence(
inputs, lengths, batch_first=True, enforce_sorted=False)
packed_outputs, _ = self.lstm(packed_inputs)
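The change above materializes the sequence lengths as a NumPy array on the host before packing. A small sketch of the packing pattern with hypothetical shapes (the padding convention, 0 = real frame and 1 = padding, follows the surrounding code):

```python
import torch

batch, max_len, feat = 4, 6, 8
inputs = torch.randn(batch, max_len, feat)
input_paddings = torch.zeros(batch, max_len)
input_paddings[:, 4:] = 1.0  # Last two steps of every sequence are padding.

# Host-side lengths, as in the diff above.
lengths = torch.sum(1 - input_paddings, dim=1).detach().cpu().numpy()
packed = torch.nn.utils.rnn.pack_padded_sequence(
    inputs, lengths, batch_first=True, enforce_sorted=False)
lstm = torch.nn.LSTM(feat, 16, batch_first=True)
packed_out, _ = lstm(packed)
outputs, out_lengths = torch.nn.utils.rnn.pad_packed_sequence(
    packed_out, batch_first=True)
```

`pack_padded_sequence` expects the lengths on the CPU when they are passed as a tensor, and it also accepts a plain list or array, which is what the added `.numpy()` call provides.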
68 changes: 26 additions & 42 deletions algorithmic_efficiency/workloads/wmt/wmt_pytorch/models.py
@@ -10,27 +10,6 @@
from torch.nn.init import xavier_uniform_


# Mask making utilities ported to PyTorch from
# https://github.com/google/flax/blob/main/flax/linen/attention.py.
def make_attention_mask(query_input: Tensor,
key_input: Tensor,
pairwise_fn: Callable[..., Any] = torch.mul,
dtype: torch.dtype = torch.float32) -> Tensor:
"""Mask-making helper for attention weights.
Args:
query_input: a batched, flat input of query_length size
key_input: a batched, flat input of key_length size
pairwise_fn: broadcasting elementwise comparison function
dtype: mask return dtype
Returns:
A `[batch..., len_q, len_kv]` shaped attention mask.
"""
mask = pairwise_fn(query_input.unsqueeze(-1), key_input.unsqueeze(-2))
return mask.to(dtype)


def make_causal_mask(x: Tensor,
device: str = 'cuda:0',
dtype: torch.dtype = torch.float32) -> Tensor:
@@ -46,17 +25,21 @@ def make_causal_mask(x: Tensor,
"""
idxs = torch.broadcast_to(
torch.arange(x.shape[-1], dtype=torch.int32, device=device), x.shape)
return make_attention_mask(idxs, idxs, torch.greater_equal, dtype=dtype)
return torch.greater_equal(idxs.unsqueeze(-1),
idxs.unsqueeze(-2)).to(dtype=dtype)


def make_src_mask(src, inputs_segmentation, nhead):
"""Utility for creating src mask and adjust it for PyTorch Transformer API."""
src_mask = make_attention_mask(src > 0, src > 0)
src_mask = torch.mul((src > 0).unsqueeze(-1),
(src > 0).unsqueeze(-2)).to(dtype=torch.float32)
# Add segmentation block-diagonal attention mask if using segmented data.
if inputs_segmentation is not None:
src_mask = torch.logical_and(
src_mask,
make_attention_mask(inputs_segmentation, inputs_segmentation, torch.eq))
torch.eq(
inputs_segmentation.unsqueeze(-1),
inputs_segmentation.unsqueeze(-2)).to(dtype=torch.float32))
# Flip values and ensure numerical stability.
src_mask = torch.repeat_interleave(
torch.logical_not(src_mask), repeats=nhead, dim=0)
@@ -75,23 +58,27 @@ def make_tgt_and_memory_mask(tgt,
Transformer API."""
if not decode:
tgt_mask = torch.logical_and(
make_attention_mask(tgt > 0, tgt > 0),
torch.mul((tgt > 0).unsqueeze(-1),
(tgt > 0).unsqueeze(-2)).to(dtype=torch.float32),
make_causal_mask(tgt, device=tgt.device))
memory_mask = make_attention_mask(tgt > 0, src > 0)
memory_mask = torch.mul((tgt > 0).unsqueeze(-1),
(src > 0).unsqueeze(-2)).to(dtype=torch.float32)
else:
tgt_mask = None
memory_mask = make_attention_mask(torch.ones_like(tgt) > 0, src > 0)
memory_mask = torch.mul((torch.ones_like(tgt) > 0).unsqueeze(-1),
(src > 0).unsqueeze(-2)).to(dtype=torch.float32)
# Add segmentation block-diagonal attention masks if using segmented data.
if inputs_segmentation is not None:
tgt_mask = torch.logical_and(
tgt_mask,
make_attention_mask(targets_segmentation,
targets_segmentation,
torch.eq))
torch.eq(
targets_segmentation.unsqueeze(-1),
targets_segmentation.unsqueeze(-2)).to(dtype=torch.float32))
memory_mask = torch.logical_and(
memory_mask,
make_attention_mask(targets_segmentation, inputs_segmentation,
torch.eq))
torch.eq(
targets_segmentation.unsqueeze(-1),
inputs_segmentation.unsqueeze(-2)).to(dtype=torch.float32))
# Flip values and ensure numerical stability.
memory_mask = torch.repeat_interleave(
torch.logical_not(memory_mask), repeats=nhead, dim=0)
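The edits above inline the deleted `make_attention_mask` helper into explicit `torch.mul`, `torch.eq`, and `torch.greater_equal` calls. A short self-check of the equivalence on toy tensors (not the workload's real inputs):

```python
import torch

src = torch.tensor([[3, 7, 2, 0, 0],
                    [5, 1, 0, 0, 0]])  # 0 denotes padding.


def make_attention_mask(query_input, key_input, pairwise_fn=torch.mul,
                        dtype=torch.float32):
  # The helper removed by this commit (ported from flax.linen.attention).
  mask = pairwise_fn(query_input.unsqueeze(-1), key_input.unsqueeze(-2))
  return mask.to(dtype)


# Padding mask, in the inlined form used after this commit.
inlined = torch.mul((src > 0).unsqueeze(-1),
                    (src > 0).unsqueeze(-2)).to(dtype=torch.float32)
assert torch.equal(make_attention_mask(src > 0, src > 0), inlined)

# Causal mask: position i may attend to positions j <= i.
idxs = torch.broadcast_to(
    torch.arange(src.shape[-1], dtype=torch.int32), src.shape)
causal = torch.greater_equal(idxs.unsqueeze(-1),
                             idxs.unsqueeze(-2)).to(dtype=torch.float32)
assert torch.equal(
    make_attention_mask(idxs, idxs, torch.greater_equal), causal)
```

Both forms produce `[batch, len_q, len_kv]` masks; the inlined version just avoids the extra helper call.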
@@ -452,8 +439,7 @@ class TransformerEncoderLayer(nn.Module):
feedforward operations, respectively. Otherwise it's done after.
Default: ``True``.
Examples::
>>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8,
batch_first=True)
>>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
>>> src = torch.rand(32, 10, 512)
>>> out = encoder_layer(src)
"""
@@ -600,7 +586,7 @@ def forward(self,
memory: the sequence from the last layer of the encoder (required).
tgt_mask: the mask for the tgt sequence (optional).
memory_mask: the mask for the memory sequence (optional).
decode: wether to use cache for autoregressive decoding or not.
decode: whether to use cache for autoregressive decoding or not.
max_len: maximum sequence length, necessary for decoding cache.
Shape:
see the docs in Transformer class.
@@ -651,13 +637,12 @@ class TransformerDecoderLayer(nn.Module):
multihead attention and feedforward operations, respectively.
Otherwise it's done after. Default: ``True``.
Examples::
>>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8,
batch_first=True)
>>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
>>> memory = torch.rand(32, 10, 512)
>>> tgt = torch.rand(32, 20, 512)
>>> out = decoder_layer(tgt, memory)
"""
__constants__ = ['batch_first', 'norm_first']
__constants__ = ['norm_first']

def __init__(self,
d_model: int = 1024,
@@ -880,10 +865,9 @@ def forward(self,
cache: cache dictionary for autoregressive decoding.
index: index of the current decoding step, necessary for decoding cache.
Outputs:
- **attn_output** - Attention outputs of shape :math:`(N, L, E)` when
``batch_first=True``, where :math:`L` is the target sequence length,
:math:`N` is the batch size, and :math:`E` is the embedding dimension
``embed_dim``.
- **attn_output** - Attention outputs of shape :math:`(N, L, E)`, where
:math:`L` is the target sequence length, :math:`N` is the batch size,
and :math:`E` is the embedding dimension ``embed_dim``.
- **cache** - For autoregressive decoding.
"""
# Shape: (batch size, sequence length, embedding dimensionality)
20 changes: 15 additions & 5 deletions algorithmic_efficiency/workloads/wmt/wmt_pytorch/workload.py
@@ -77,17 +77,27 @@ def predict_step(self,
max_decode_len: int,
beam_size: int = 4) -> spec.Tensor:
"""Predict translation with fast decoding beam search on a batch."""
params = params.module if isinstance(params, (DP, DDP)) else params
# params = params.module if isinstance(params, (DP, DDP)) else params
if hasattr(params, 'module'):
params = params.module
params.eval()
encoder = params.encoder

if hasattr(params, '_modules'):
params = params._modules
encoder = params["encoder"]
decoder = params["decoder"]
else:
encoder = params.encoder
decoder = params.decoder

if N_GPUS > 1 and not USE_PYTORCH_DDP:
encoder = DP(encoder)
if N_GPUS > 1 and not USE_PYTORCH_DDP:
decoder = DP(decoder)

encoded_inputs = torch.repeat_interleave(
encoder(inputs), repeats=beam_size, dim=0)
raw_inputs = torch.repeat_interleave(inputs, repeats=beam_size, dim=0)
decoder = params.decoder
if N_GPUS > 1 and not USE_PYTORCH_DDP:
decoder = DP(decoder)

def tokens_ids_to_logits(
flat_ids: spec.Tensor, flat_cache: Dict[str, spec.Tensor]
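`predict_step` above switches from an `isinstance` check to `hasattr` checks for unwrapping a possibly DataParallel/DDP-wrapped model. A minimal sketch of the unwrapping idiom with a toy module (not the workload's encoder-decoder Transformer):

```python
import torch
from torch.nn.parallel import DataParallel as DP

model = torch.nn.Sequential(torch.nn.Linear(4, 4))
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
  model = DP(model)  # Wrappers expose the inner network as `.module`.

params = model
if hasattr(params, 'module'):
  params = params.module  # Reach the underlying nn.Sequential.

print(type(params))  # torch.nn.Sequential whether or not it was wrapped.
```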
2 changes: 1 addition & 1 deletion baselines/adamw/pytorch/submission.py
@@ -106,7 +106,7 @@ def update_params(workload: spec.Workload,
optimizer_state['scheduler'].step()

# Log training metrics - loss, grad_norm, batch_size.
if global_step <= 100 or global_step % 500 == 0:
if global_step <= 10 or global_step % 500 == 0:
with torch.no_grad():
parameters = [p for p in current_model.parameters() if p.grad is not None]
grad_norm = torch.norm(
2 changes: 2 additions & 0 deletions datasets/dataset_setup.py
@@ -88,6 +88,8 @@
import requests
import tqdm

import tensorflow as tf

IMAGENET_TRAIN_TAR_FILENAME = 'ILSVRC2012_img_train.tar'
IMAGENET_VAL_TAR_FILENAME = 'ILSVRC2012_img_val.tar'

12 changes: 9 additions & 3 deletions submission_runner.py
@@ -15,6 +15,7 @@
"""

import datetime
import gc
import importlib
import json
import os
@@ -191,8 +192,10 @@ def train_once(
model_params, model_state = workload.init_model_fn(
model_init_rng, dropout_rate, aux_dropout_rate)
if FLAGS.framework == 'pytorch' and FLAGS.torch_compile:
compile_error_workloads = ['ogbg', 'librispeech_deepspeech', 'wmt']
eager_backend_workloads = ['librispeech_conformer']
compile_error_workloads = ['ogbg']
eager_backend_workloads = [
'librispeech_conformer', 'librispeech_deepspeech'
]
aot_eager_backend_workloads = ['criteo1tb']
if FLAGS.workload in compile_error_workloads:
logging.warning(
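The hunk above adjusts which workloads are excluded from `torch.compile` and which fall back to the `eager` or `aot_eager` backends. A simplified sketch of that selection logic (a hypothetical `maybe_compile` helper; flag handling and logging omitted, so this is not the runner's exact code path):

```python
import torch


def maybe_compile(model: torch.nn.Module, workload_name: str):
  # Workload lists mirror the diff above.
  compile_error_workloads = ['ogbg']
  eager_backend_workloads = ['librispeech_conformer', 'librispeech_deepspeech']
  aot_eager_backend_workloads = ['criteo1tb']
  if workload_name in compile_error_workloads:
    return model  # Known to fail under torch.compile; run uncompiled.
  if workload_name in eager_backend_workloads:
    return torch.compile(model, backend='eager')
  if workload_name in aot_eager_backend_workloads:
    return torch.compile(model, backend='aot_eager')
  return torch.compile(model)  # Default inductor backend.
```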
@@ -387,9 +390,12 @@ def train_once(
.save_intermediate_checkpoints)

if FLAGS.framework == 'pytorch' and torch.cuda.is_available():
# Clean up the GPU cache after evaluation.
gc.collect()
torch.cuda.empty_cache()
logging_end_time = get_time()
logging.info('Released all unoccupied cached memory.')

logging_end_time = get_time()
train_state['accumulated_logging_time'] += (
logging_end_time - logging_start_time)

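The cleanup added above frees GPU memory between evaluation and the next training step. A minimal standalone version of the same idiom (a hypothetical helper, not the runner's function):

```python
import gc
import torch


def release_gpu_cache() -> None:
  """Run the Python GC, then release unused cached GPU memory."""
  if torch.cuda.is_available():
    gc.collect()
    torch.cuda.empty_cache()
```

Note that `torch.cuda.empty_cache()` only releases memory that is cached but not held by live tensors; it does not free active allocations.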
