Merge pull request mlcommons#502 from mlcommons/juhan/legit_fix
Fix for Criteo OOM
priyakasimbeg authored Aug 28, 2023
2 parents 7eeb235 + f9790ff commit 4c38ffb
Showing 7 changed files with 119 additions and 102 deletions.
algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py
@@ -18,7 +18,7 @@ def dot_interact(concat_features):
"""
batch_size = concat_features.shape[0]

# Interact features, select upper or lower-triangular portion, and re-shape.
# Interact features, select upper or lower-triangular portion, and reshape.
xactions = jnp.matmul(concat_features,
jnp.transpose(concat_features, [0, 2, 1]))
feature_dim = xactions.shape[-1]
@@ -46,7 +46,7 @@ class DlrmSmall(nn.Module):
embed_dim: embedding dimension.
"""

vocab_size: int = 32 * 128 * 1024 # 4_194_304
vocab_size: int = 32 * 128 * 1024 # 4_194_304.
num_dense_features: int = 13
mlp_bottom_dims: Sequence[int] = (512, 256, 128)
mlp_top_dims: Sequence[int] = (1024, 1024, 512, 256, 1)
102 changes: 54 additions & 48 deletions algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py
@@ -6,33 +6,23 @@
from torch import nn


def dot_interact(concat_features):
"""Performs feature interaction operation between dense or sparse features.
Input tensors represent dense or sparse features.
Pre-condition: The tensors have been stacked along dimension 1.
Args:
concat_features: Array of features with shape [B, n_features, feature_dim].
Returns:
activations: Array representing interacted features.
"""
batch_size = concat_features.shape[0]

# Interact features, select upper or lower-triangular portion, and re-shape.
xactions = torch.bmm(concat_features,
torch.permute(concat_features, (0, 2, 1)))
feature_dim = xactions.shape[-1]

indices = torch.triu_indices(feature_dim, feature_dim)
num_elems = indices.shape[1]
indices = torch.tile(indices, [1, batch_size])
indices0 = torch.reshape(
torch.tile(
torch.reshape(torch.arange(batch_size), [-1, 1]), [1, num_elems]),
[1, -1])
indices = tuple(torch.cat((indices0, indices), 0))
activations = xactions[indices]
activations = torch.reshape(activations, [batch_size, -1])
return activations
class DotInteract(nn.Module):
"""Performs feature interaction operation between dense or sparse features."""

def __init__(self, num_sparse_features):
super().__init__()
self.triu_indices = torch.triu_indices(num_sparse_features + 1,
num_sparse_features + 1)

def forward(self, dense_features, sparse_features):
combined_values = torch.cat((dense_features.unsqueeze(1), sparse_features),
dim=1)
interactions = torch.bmm(combined_values,
torch.transpose(combined_values, 1, 2))
interactions_flat = interactions[:,
self.triu_indices[0],
self.triu_indices[1]]
return torch.cat((dense_features, interactions_flat), dim=1)


class DlrmSmall(nn.Module):
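For context, a small standalone sketch (not part of the commit) that mirrors the new DotInteract.forward using the workload defaults embed_dim=128 and num_sparse_features=26; the batch size is illustrative. It also shows why the top-MLP input width hard-coded later in this diff is 506 (embed_dim plus the 27 * 28 / 2 = 378 upper-triangular interaction terms):

import torch

batch_size, embed_dim, num_sparse = 8, 128, 26
dense = torch.randn(batch_size, embed_dim)
sparse = torch.randn(batch_size, num_sparse, embed_dim)

combined = torch.cat((dense.unsqueeze(1), sparse), dim=1)            # [8, 27, 128]
interactions = torch.bmm(combined, torch.transpose(combined, 1, 2))  # [8, 27, 27]
triu = torch.triu_indices(num_sparse + 1, num_sparse + 1)            # 378 index pairs
flat = interactions[:, triu[0], triu[1]]                             # [8, 378]
out = torch.cat((dense, flat), dim=1)                                # [8, 506]
assert out.shape == (batch_size, 506)
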
@@ -62,13 +52,21 @@ def __init__(self,
self.mlp_top_dims = mlp_top_dims
self.embed_dim = embed_dim

self.embedding_table = nn.Embedding(self.vocab_size, self.embed_dim)
self.embedding_table.weight.data.uniform_(0, 1)
# Scale the initialization to fan_in for each slice.
# Ideally, we should use the pooled embedding implementation from
# `TorchRec`. However, in order to have identical implementation
# with that of Jax, we define a single embedding matrix.
num_chucks = 4
assert vocab_size % num_chucks == 0
self.embedding_table_chucks = []
scale = 1.0 / torch.sqrt(self.vocab_size)
self.embedding_table.weight.data = scale * self.embedding_table.weight.data
for i in range(num_chucks):
chunk = nn.Parameter(
torch.Tensor(self.vocab_size // num_chucks, self.embed_dim))
chunk.data.uniform_(0, 1)
chunk.data = scale * chunk.data
self.register_parameter(f'embedding_chunk_{i}', chunk)
self.embedding_table_chucks.append(chunk)

# bottom mlp
bottom_mlp_layers = []
input_dim = self.num_dense_features
for dense_dim in self.mlp_bottom_dims:
@@ -84,8 +82,9 @@ def __init__(self,
0.,
math.sqrt(1. / module.out_features))

# top mlp
# TODO (JB): Write down the formula here instead of the constant.
self.dot_interact = DotInteract(num_sparse_features=num_sparse_features,)

# TODO: Write down the formula here instead of the constant.
input_dims = 506
top_mlp_layers = []
num_layers_top = len(self.mlp_top_dims)
@@ -110,19 +109,26 @@ def __init__(self,
math.sqrt(1. / module.out_features))

def forward(self, x):
bot_mlp_input, cat_features = torch.split(
batch_size = x.shape[0]

dense_features, sparse_features = torch.split(
x, [self.num_dense_features, self.num_sparse_features], 1)
cat_features = cat_features.to(dtype=torch.int32)
bot_mlp_output = self.bot_mlp(bot_mlp_input)
batch_size = bot_mlp_output.shape[0]
feature_stack = torch.reshape(bot_mlp_output,
[batch_size, -1, self.embed_dim])
idx_lookup = torch.reshape(cat_features, [-1]) % self.vocab_size
embed_features = self.embedding_table(idx_lookup)
embed_features = torch.reshape(embed_features,
[batch_size, -1, self.embed_dim])
feature_stack = torch.cat([feature_stack, embed_features], axis=1)
dot_interact_output = dot_interact(concat_features=feature_stack)
top_mlp_input = torch.cat([bot_mlp_output, dot_interact_output], axis=-1)
logits = self.top_mlp(top_mlp_input)

# Bottom MLP.
embedded_dense = self.bot_mlp(dense_features)

# Sparse feature processing.
sparse_features = sparse_features.to(dtype=torch.int32)
idx_lookup = torch.reshape(sparse_features, [-1]) % self.vocab_size
embedding_table = torch.cat(self.embedding_table_chucks, dim=0)
embedded_sparse = embedding_table[idx_lookup]
embedded_sparse = torch.reshape(embedded_sparse,
[batch_size, -1, self.embed_dim])

# Dot product interactions.
concatenated_dense = self.dot_interact(
dense_features=embedded_dense, sparse_features=embedded_sparse)

# Final MLP.
logits = self.top_mlp(concatenated_dense)
return logits
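
A scaled-down sketch (not the commit's code; the sizes below are illustrative, not the workload's) of how the four embedding chunks relate to the original single table, and to the parameter-count adjustments in the test files later in this diff:

import torch
from torch import nn

vocab_size, embed_dim, num_chunks = 1024, 16, 4  # illustrative sizes
assert vocab_size % num_chunks == 0

chunks = [nn.Parameter(torch.empty(vocab_size // num_chunks, embed_dim))
          for _ in range(num_chunks)]
full_table = torch.cat([c.data for c in chunks], dim=0)  # what forward() rebuilds for lookup
assert full_table.shape == (vocab_size, embed_dim)

# One Jax embedding parameter corresponds to four PyTorch chunk parameters with
# the same total number of weights, hence the `- 3` offset in test_param_shapes.
assert sum(c.numel() for c in chunks) == vocab_size * embed_dim
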
algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py
@@ -1,9 +1,8 @@
"""Criteo1TB workload implemented in PyTorch."""

import contextlib
from typing import Dict, Optional, Tuple
from typing import Dict, Iterator, Optional, Tuple

import jax
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
@@ -23,7 +22,7 @@ class Criteo1TbDlrmSmallWorkload(BaseCriteo1TbDlrmSmallWorkload):

@property
def eval_batch_size(self) -> int:
return 262_144
return 32_768

def _per_example_sigmoid_binary_cross_entropy(
self, logits: spec.Tensor, targets: spec.Tensor) -> spec.Tensor:
@@ -67,11 +66,6 @@ def loss_fn(
'per_example': per_example_losses,
}

def _eval_metric(self, logits: spec.Tensor,
targets: spec.Tensor) -> Dict[str, int]:
summed_loss = self.loss_fn(logits, targets)['summed']
return {'loss': summed_loss}

def init_model_fn(
self,
rng: spec.RandomState,
@@ -133,25 +127,28 @@ def model_fn(

return logits_batch, None

def _build_input_queue(self,
data_rng: jax.random.PRNGKey,
split: str,
data_dir: str,
global_batch_size: int,
num_batches: Optional[int] = None,
repeat_final_dataset: bool = False):
def _build_input_queue(
self,
data_rng: spec.RandomState,
split: str,
data_dir: str,
global_batch_size: int,
cache: Optional[bool] = None,
repeat_final_dataset: Optional[bool] = None,
num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]:
not_train = split != 'train'
per_device_batch_size = int(global_batch_size / N_GPUS)

# Only create and iterate over tf input pipeline in one Python process to
# avoid creating too many threads.
if RANK == 0:
np_iter = super()._build_input_queue(data_rng,
split,
data_dir,
global_batch_size,
num_batches,
repeat_final_dataset)
np_iter = super()._build_input_queue(
data_rng=data_rng,
split=split,
data_dir=data_dir,
global_batch_size=global_batch_size,
num_batches=num_batches,
repeat_final_dataset=repeat_final_dataset)
weights = None
while True:
if RANK == 0:
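
A rough, hypothetical sketch of the rank-0 input-pipeline pattern this hunk sets up (the helper name, shapes, and broadcast details below are assumptions, not the commit's code):

import torch
import torch.distributed as dist


def shard_batches(np_iter, per_device_batch_size, num_features, rank, n_gpus,
                  device):
  """Rank 0 reads from the numpy/tf pipeline; other ranks receive by broadcast."""
  while True:
    if rank == 0:
      np_batch = next(np_iter)  # numpy dict holding the full global batch
      inputs = torch.as_tensor(
          np_batch['inputs'], dtype=torch.float32, device=device)
      inputs = inputs.reshape(n_gpus, per_device_batch_size, num_features)
    else:
      inputs = torch.empty((n_gpus, per_device_batch_size, num_features),
                           dtype=torch.float32,
                           device=device)
    dist.broadcast(inputs, src=0)   # one tf pipeline feeds every process
    yield {'inputs': inputs[rank]}  # each rank keeps only its own shard
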
19 changes: 6 additions & 13 deletions algorithmic_efficiency/workloads/criteo1tb/workload.py
@@ -18,7 +18,7 @@
class BaseCriteo1TbDlrmSmallWorkload(spec.Workload):
"""Criteo1tb workload."""

vocab_size: int = 32 * 128 * 1024 # 4_194_304
vocab_size: int = 32 * 128 * 1024 # 4_194_304.
num_dense_features: int = 13
mlp_bottom_dims: Tuple[int, int] = (512, 256, 128)
mlp_top_dims: Tuple[int, int, int] = (1024, 1024, 512, 256, 1)
@@ -128,11 +128,11 @@ def _eval_model_on_split(self,
if split not in self._eval_iters:
# These iterators will repeat indefinitely.
self._eval_iters[split] = self._build_input_queue(
rng,
split,
data_dir,
global_batch_size,
num_batches,
data_rng=rng,
split=split,
data_dir=data_dir,
global_batch_size=global_batch_size,
num_batches=num_batches,
repeat_final_dataset=True)
loss = 0.0
for _ in range(num_batches):
@@ -141,11 +141,4 @@ def _eval_model_on_split(self,
if USE_PYTORCH_DDP:
dist.all_reduce(loss)
mean_loss = loss.item() / num_examples
if FLAGS.framework == 'pytorch':
# For PyTorch, the saved iterators cause OOM after evaluation.
# Hence, we create new iterators for each evaluation step. While this
# slows down the overall time to perform evaluation, this does not affect
# the final score.
del self._eval_iters
self._eval_iters = {}
return {'loss': mean_loss}
34 changes: 20 additions & 14 deletions submission_runner.py
@@ -155,6 +155,14 @@ def _get_time_ddp():
get_time = _get_time


def _reset_cuda_mem():
if FLAGS.framework == 'pytorch' and torch.cuda.is_available():
torch._C._cuda_clearCublasWorkspaces()
gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()


def train_once(
workload: spec.Workload,
global_batch_size: int,
@@ -192,11 +200,11 @@ def train_once(
model_params, model_state = workload.init_model_fn(
model_init_rng, dropout_rate, aux_dropout_rate)
if FLAGS.framework == 'pytorch' and FLAGS.torch_compile:
compile_error_workloads = ['ogbg']
compile_error_workloads = ['ogbg', 'criteo1tb']
eager_backend_workloads = [
'librispeech_conformer', 'librispeech_deepspeech'
]
aot_eager_backend_workloads = ['criteo1tb']
aot_eager_backend_workloads = []
if FLAGS.workload in compile_error_workloads:
logging.warning(
'These workloads cannot be fully compiled under current '
@@ -325,6 +333,9 @@
if ((train_step_end_time - train_state['last_eval_time']) >=
workload.eval_period_time_sec or train_state['training_complete']):
with profiler.profile('Evaluation'):
del batch
_reset_cuda_mem()

try:
eval_start_time = get_time()
latest_eval_result = workload.eval_model(global_eval_batch_size,
@@ -334,23 +345,23 @@ def train_once(
data_dir,
imagenet_v2_data_dir,
global_step)
# Check if targets reached
# Check if targets reached.
train_state['validation_goal_reached'] = (
workload.has_reached_validation_target(latest_eval_result) or
train_state['validation_goal_reached'])
train_state['test_goal_reached'] = (
workload.has_reached_test_target(latest_eval_result) or
train_state['test_goal_reached'])

# Save last eval time
# Save last eval time.
eval_end_time = get_time()
train_state['last_eval_time'] = eval_end_time

# Accumulate eval time
# Accumulate eval time.
train_state[
'accumulated_eval_time'] += eval_end_time - eval_start_time

# Add times to eval results for logging
# Add times to eval results for logging.
latest_eval_result['score'] = (
train_state['accumulated_submission_time'])
latest_eval_result[
@@ -389,23 +400,18 @@
save_intermediate_checkpoints=FLAGS
.save_intermediate_checkpoints)

if FLAGS.framework == 'pytorch' and torch.cuda.is_available():
# Clean up the GPU cache after evaluation.
gc.collect()
torch.cuda.empty_cache()
logging.info('Released all unoccupied cached memory.')

logging_end_time = get_time()
train_state['accumulated_logging_time'] += (
logging_end_time - logging_start_time)

_reset_cuda_mem()

except RuntimeError as e:
logging.exception(f'Eval step {global_step} error.\n')
if 'out of memory' in str(e):
logging.warning('Error: GPU out of memory during eval during step '
f'{global_step}, error : {str(e)}.')
if torch.cuda.is_available():
torch.cuda.empty_cache()
_reset_cuda_mem()

train_state['last_step_end_time'] = get_time()

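
A short usage sketch (not part of the commit) pairing the new _reset_cuda_mem helper with PyTorch's peak-memory counters when chasing eval OOMs; eval_fn is a placeholder, and the helper is copied from the hunk above minus the FLAGS.framework check:

import gc
import torch


def _reset_cuda_mem():
  if torch.cuda.is_available():
    torch._C._cuda_clearCublasWorkspaces()
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()


def eval_with_memory_report(eval_fn):
  _reset_cuda_mem()
  result = eval_fn()
  if torch.cuda.is_available():
    peak_gib = torch.cuda.max_memory_allocated() / 2**30
    print(f'Peak GPU memory during eval: {peak_gib:.2f} GiB')
  _reset_cuda_mem()
  return result
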
18 changes: 14 additions & 4 deletions tests/test_param_shapes.py
@@ -1,3 +1,5 @@
from itertools import zip_longest

import jax
import numpy as np
import pytest
@@ -53,13 +55,21 @@ def test_param_shapes(workload):
jax_workload.param_shapes.unfreeze())
pytorch_param_shapes = jax.tree_util.tree_leaves(
pytorch_workload.param_shapes)
assert len(jax_param_shapes) == len(pytorch_param_shapes)
if workload == 'criteo1tb':
    # The PyTorch implementation divides the embedding matrix into 4 chunks,
    # so it has 3 more parameters than the Jax implementation.
assert len(jax_param_shapes) == len(pytorch_param_shapes) - 3
else:
assert len(jax_param_shapes) == len(pytorch_param_shapes)
# Check if total number of params deduced from shapes match.
num_jax_params = 0
num_pytorch_params = 0
for jax_shape, pytorch_shape in zip(jax_param_shapes, pytorch_param_shapes):
num_jax_params += np.prod(jax_shape.shape_tuple)
num_pytorch_params += np.prod(pytorch_shape.shape_tuple)
for jax_shape, pytorch_shape in zip_longest(jax_param_shapes,
pytorch_param_shapes):
if jax_shape is not None:
num_jax_params += np.prod(jax_shape.shape_tuple)
if pytorch_shape is not None:
num_pytorch_params += np.prod(pytorch_shape.shape_tuple)
assert num_jax_params == num_pytorch_params


5 changes: 5 additions & 0 deletions tests/test_param_types.py
@@ -129,6 +129,11 @@ def test_param_types(workload_name):
jax_param_types_dict = count_param_types(jax_param_types)
pytorch_param_types_dict = count_param_types(pytorch_param_types)

  # PyTorch splits the embedding matrix into 4 chunks.
if workload_name == 'criteo1tb':
pytorch_param_types_dict[spec.ParameterType.WEIGHT] -= 4
pytorch_param_types_dict[spec.ParameterType.EMBEDDING] = 1

# Jax fuses LSTM cells together, whereas PyTorch exposes all the weight
# parameters, and there are two per cell, for each of the forward and backward
# directional LSTMs, and there are 6 layers of LSTM in librispeech_deepspeech,
