
Commit

Merge branch 'dev' into wmt-speed
Conflicts:
	tests/test_param_shapes.py
runame committed Aug 29, 2023
2 parents 8d99d81 + 4c38ffb commit 04ffcc0
Showing 12 changed files with 283 additions and 172 deletions.
111 changes: 79 additions & 32 deletions .github/workflows/regression_tests.yml

Large diffs are not rendered by default.

algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py
@@ -18,7 +18,7 @@ def dot_interact(concat_features):
   """
   batch_size = concat_features.shape[0]
 
-  # Interact features, select upper or lower-triangular portion, and re-shape.
+  # Interact features, select upper or lower-triangular portion, and reshape.
   xactions = jnp.matmul(concat_features,
                         jnp.transpose(concat_features, [0, 2, 1]))
   feature_dim = xactions.shape[-1]
@@ -46,7 +46,7 @@ class DlrmSmall(nn.Module):
     embed_dim: embedding dimension.
   """
 
-  vocab_size: int = 32 * 128 * 1024  # 4_194_304
+  vocab_size: int = 32 * 128 * 1024  # 4_194_304.
   num_dense_features: int = 13
   mlp_bottom_dims: Sequence[int] = (512, 256, 128)
   mlp_top_dims: Sequence[int] = (1024, 1024, 512, 256, 1)
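To make the comment above concrete, here is a tiny, self-contained sketch (an editorial illustration, not code from this commit) of selecting the upper-triangular portion of the pairwise interaction matrix; it uses PyTorch for consistency with the PyTorch diff below:

import torch

# Three stacked feature vectors for a single example, each of dimension 2.
features = torch.tensor([[1., 0.], [0., 2.], [3., 1.]])
xactions = features @ features.T   # [3, 3] matrix of all pairwise dot products
iu = torch.triu_indices(3, 3)      # upper triangle incl. diagonal: 6 entries
flat = xactions[iu[0], iu[1]]      # the distinct interactions, flattened
assert flat.numel() == 3 * 4 // 2  # n * (n + 1) / 2 with n = 3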
102 changes: 54 additions & 48 deletions algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/models.py
@@ -6,33 +6,23 @@
 from torch import nn
 
 
-def dot_interact(concat_features):
-  """Performs feature interaction operation between dense or sparse features.
-  Input tensors represent dense or sparse features.
-  Pre-condition: The tensors have been stacked along dimension 1.
-  Args:
-    concat_features: Array of features with shape [B, n_features, feature_dim].
-  Returns:
-    activations: Array representing interacted features.
-  """
-  batch_size = concat_features.shape[0]
-
-  # Interact features, select upper or lower-triangular portion, and re-shape.
-  xactions = torch.bmm(concat_features,
-                       torch.permute(concat_features, (0, 2, 1)))
-  feature_dim = xactions.shape[-1]
-
-  indices = torch.triu_indices(feature_dim, feature_dim)
-  num_elems = indices.shape[1]
-  indices = torch.tile(indices, [1, batch_size])
-  indices0 = torch.reshape(
-      torch.tile(
-          torch.reshape(torch.arange(batch_size), [-1, 1]), [1, num_elems]),
-      [1, -1])
-  indices = tuple(torch.cat((indices0, indices), 0))
-  activations = xactions[indices]
-  activations = torch.reshape(activations, [batch_size, -1])
-  return activations
+class DotInteract(nn.Module):
+  """Performs feature interaction operation between dense or sparse features."""
+
+  def __init__(self, num_sparse_features):
+    super().__init__()
+    self.triu_indices = torch.triu_indices(num_sparse_features + 1,
+                                           num_sparse_features + 1)
+
+  def forward(self, dense_features, sparse_features):
+    combined_values = torch.cat((dense_features.unsqueeze(1), sparse_features),
+                                dim=1)
+    interactions = torch.bmm(combined_values,
+                             torch.transpose(combined_values, 1, 2))
+    interactions_flat = interactions[:,
+                                     self.triu_indices[0],
+                                     self.triu_indices[1]]
+    return torch.cat((dense_features, interactions_flat), dim=1)
 
 
 class DlrmSmall(nn.Module):
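As a sanity check on the new module, a minimal sketch (editorial, not part of the commit; the batch size is arbitrary, while 26 sparse features and embed_dim = 128 are the workload defaults) showing how DotInteract's output width arises:

import torch

batch_size, embed_dim, num_sparse_features = 8, 128, 26
dot_interact = DotInteract(num_sparse_features=num_sparse_features)

dense = torch.randn(batch_size, embed_dim)                        # bottom-MLP output
sparse = torch.randn(batch_size, num_sparse_features, embed_dim)  # embedded IDs
out = dot_interact(dense_features=dense, sparse_features=sparse)

# 27 stacked features yield 27 * 28 / 2 = 378 upper-triangular interaction
# terms (diagonal included); with the 128 dense values appended this gives
# 378 + 128 = 506, matching the `input_dims = 506` constant below.
assert out.shape == (batch_size, 506)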
@@ -62,13 +52,21 @@ def __init__(self,
     self.mlp_top_dims = mlp_top_dims
     self.embed_dim = embed_dim
 
-    self.embedding_table = nn.Embedding(self.vocab_size, self.embed_dim)
-    self.embedding_table.weight.data.uniform_(0, 1)
-    # Scale the initialization to fan_in for each slice.
+    # Ideally, we should use the pooled embedding implementation from
+    # `TorchRec`. However, in order to keep this implementation identical
+    # to the Jax one, we define a single embedding matrix.
+    num_chunks = 4
+    assert vocab_size % num_chunks == 0
+    self.embedding_table_chunks = []
     scale = 1.0 / torch.sqrt(self.vocab_size)
-    self.embedding_table.weight.data = scale * self.embedding_table.weight.data
+    for i in range(num_chunks):
+      chunk = nn.Parameter(
+          torch.Tensor(self.vocab_size // num_chunks, self.embed_dim))
+      chunk.data.uniform_(0, 1)
+      chunk.data = scale * chunk.data
+      self.register_parameter(f'embedding_chunk_{i}', chunk)
+      self.embedding_table_chunks.append(chunk)
 
-    # bottom mlp
     bottom_mlp_layers = []
     input_dim = self.num_dense_features
     for dense_dim in self.mlp_bottom_dims:
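The chunked table is numerically equivalent to one [vocab_size, embed_dim] matrix; chunking only changes how the parameters are registered. A minimal sketch (editorial; sizes are illustrative, not the workload's):

import torch
from torch import nn

vocab_size, embed_dim, num_chunks = 16, 4, 4
chunks = [
    nn.Parameter(torch.randn(vocab_size // num_chunks, embed_dim))
    for _ in range(num_chunks)
]
table = torch.cat(chunks, dim=0)        # [vocab_size, embed_dim], as in forward()
rows = table[torch.tensor([0, 5, 15])]  # same rows a single nn.Embedding returns
assert rows.shape == (3, embed_dim)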
@@ -84,8 +82,9 @@ def __init__(self,
                         0.,
                         math.sqrt(1. / module.out_features))
 
-    # top mlp
-    # TODO (JB): Write down the formula here instead of the constant.
+    self.dot_interact = DotInteract(num_sparse_features=num_sparse_features)
+
+    # TODO: Write down the formula here instead of the constant.
     input_dims = 506
     top_mlp_layers = []
     num_layers_top = len(self.mlp_top_dims)
@@ -110,19 +109,26 @@ def __init__(self,
                         math.sqrt(1. / module.out_features))
 
   def forward(self, x):
-    bot_mlp_input, cat_features = torch.split(
+    batch_size = x.shape[0]
+
+    dense_features, sparse_features = torch.split(
         x, [self.num_dense_features, self.num_sparse_features], 1)
-    cat_features = cat_features.to(dtype=torch.int32)
-    bot_mlp_output = self.bot_mlp(bot_mlp_input)
-    batch_size = bot_mlp_output.shape[0]
-    feature_stack = torch.reshape(bot_mlp_output,
-                                  [batch_size, -1, self.embed_dim])
-    idx_lookup = torch.reshape(cat_features, [-1]) % self.vocab_size
-    embed_features = self.embedding_table(idx_lookup)
-    embed_features = torch.reshape(embed_features,
-                                   [batch_size, -1, self.embed_dim])
-    feature_stack = torch.cat([feature_stack, embed_features], axis=1)
-    dot_interact_output = dot_interact(concat_features=feature_stack)
-    top_mlp_input = torch.cat([bot_mlp_output, dot_interact_output], axis=-1)
-    logits = self.top_mlp(top_mlp_input)
+
+    # Bottom MLP.
+    embedded_dense = self.bot_mlp(dense_features)
+
+    # Sparse feature processing.
+    sparse_features = sparse_features.to(dtype=torch.int32)
+    idx_lookup = torch.reshape(sparse_features, [-1]) % self.vocab_size
+    embedding_table = torch.cat(self.embedding_table_chunks, dim=0)
+    embedded_sparse = embedding_table[idx_lookup]
+    embedded_sparse = torch.reshape(embedded_sparse,
+                                    [batch_size, -1, self.embed_dim])
+
+    # Dot product interactions.
+    concatenated_dense = self.dot_interact(
+        dense_features=embedded_dense, sparse_features=embedded_sparse)
+
+    # Final MLP.
+    logits = self.top_mlp(concatenated_dense)
     return logits
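The rewritten forward pass assumes each input row packs the 13 dense features followed by the 26 sparse categorical IDs. A short editorial sketch of the split and the modulo hashing into the embedding table (values are illustrative):

import torch

batch_size, num_dense, num_sparse = 4, 13, 26
vocab_size = 32 * 128 * 1024
x = torch.randn(batch_size, num_dense + num_sparse)

dense, sparse = torch.split(x, [num_dense, num_sparse], dim=1)
idx = torch.reshape(sparse.to(torch.int32), [-1]) % vocab_size  # hash into table
assert dense.shape == (batch_size, num_dense)
assert idx.shape == (batch_size * num_sparse,)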
algorithmic_efficiency/workloads/criteo1tb/criteo1tb_pytorch/workload.py
@@ -1,9 +1,8 @@
 """Criteo1TB workload implemented in PyTorch."""
 
 import contextlib
-from typing import Dict, Optional, Tuple
+from typing import Dict, Iterator, Optional, Tuple
 
-import jax
 import torch
 import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -23,7 +22,7 @@ class Criteo1TbDlrmSmallWorkload(BaseCriteo1TbDlrmSmallWorkload):
 
   @property
   def eval_batch_size(self) -> int:
-    return 262_144
+    return 32_768
 
   def _per_example_sigmoid_binary_cross_entropy(
       self, logits: spec.Tensor, targets: spec.Tensor) -> spec.Tensor:
@@ -67,11 +66,6 @@ def loss_fn(
         'per_example': per_example_losses,
     }
 
-  def _eval_metric(self, logits: spec.Tensor,
-                   targets: spec.Tensor) -> Dict[str, int]:
-    summed_loss = self.loss_fn(logits, targets)['summed']
-    return {'loss': summed_loss}
-
   def init_model_fn(
       self,
       rng: spec.RandomState,
@@ -133,25 +127,28 @@ def model_fn(
 
     return logits_batch, None
 
-  def _build_input_queue(self,
-                         data_rng: jax.random.PRNGKey,
-                         split: str,
-                         data_dir: str,
-                         global_batch_size: int,
-                         num_batches: Optional[int] = None,
-                         repeat_final_dataset: bool = False):
+  def _build_input_queue(
+      self,
+      data_rng: spec.RandomState,
+      split: str,
+      data_dir: str,
+      global_batch_size: int,
+      cache: Optional[bool] = None,
+      repeat_final_dataset: Optional[bool] = None,
+      num_batches: Optional[int] = None) -> Iterator[Dict[str, spec.Tensor]]:
     not_train = split != 'train'
     per_device_batch_size = int(global_batch_size / N_GPUS)
 
     # Only create and iterate over tf input pipeline in one Python process to
    # avoid creating too many threads.
     if RANK == 0:
-      np_iter = super()._build_input_queue(data_rng,
-                                           split,
-                                           data_dir,
-                                           global_batch_size,
-                                           num_batches,
-                                           repeat_final_dataset)
+      np_iter = super()._build_input_queue(
+          data_rng=data_rng,
+          split=split,
+          data_dir=data_dir,
+          global_batch_size=global_batch_size,
+          num_batches=num_batches,
+          repeat_final_dataset=repeat_final_dataset)
     weights = None
     while True:
       if RANK == 0:
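A minimal sketch of the rank-0 pattern described in the comment above (editorial; it assumes torch.distributed is already initialized and that the iterator yields arrays of a fixed, known shape): one process owns the tf.data pipeline and broadcasts each batch to the remaining ranks.

import torch
import torch.distributed as dist

def rank0_input_queue(np_iter, rank, batch_shape):
  # Every rank allocates an identically shaped buffer; rank 0 fills it from
  # the host-side input pipeline, then broadcasts it to all other ranks.
  while True:
    buf = torch.empty(batch_shape)
    if rank == 0:
      buf.copy_(torch.as_tensor(next(np_iter)))
    dist.broadcast(buf, src=0)
    yield buf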
19 changes: 6 additions & 13 deletions algorithmic_efficiency/workloads/criteo1tb/workload.py
@@ -18,7 +18,7 @@
 class BaseCriteo1TbDlrmSmallWorkload(spec.Workload):
   """Criteo1tb workload."""
 
-  vocab_size: int = 32 * 128 * 1024  # 4_194_304
+  vocab_size: int = 32 * 128 * 1024  # 4_194_304.
   num_dense_features: int = 13
   mlp_bottom_dims: Tuple[int, ...] = (512, 256, 128)
   mlp_top_dims: Tuple[int, ...] = (1024, 1024, 512, 256, 1)
@@ -128,11 +128,11 @@ def _eval_model_on_split(self,
     if split not in self._eval_iters:
       # These iterators will repeat indefinitely.
       self._eval_iters[split] = self._build_input_queue(
-          rng,
-          split,
-          data_dir,
-          global_batch_size,
-          num_batches,
+          data_rng=rng,
+          split=split,
+          data_dir=data_dir,
+          global_batch_size=global_batch_size,
+          num_batches=num_batches,
           repeat_final_dataset=True)
     loss = 0.0
     for _ in range(num_batches):
@@ -141,11 +141,4 @@ def _eval_model_on_split(self,
     if USE_PYTORCH_DDP:
       dist.all_reduce(loss)
     mean_loss = loss.item() / num_examples
-    if FLAGS.framework == 'pytorch':
-      # For PyTorch, the saved iterators cause OOM after evaluation.
-      # Hence, we create new iterators for each evaluation step. While this
-      # slows down the overall time to perform evaluation, this does not affect
-      # the final score.
-      del self._eval_iters
-      self._eval_iters = {}
     return {'loss': mean_loss}