Criteo workload variants #568

Merged: 161 commits, Dec 7, 2023

Commits
3933a34
criteo variants
priyakasimbeg Nov 10, 2023
5f40758
add workload variants
priyakasimbeg Nov 10, 2023
38821bb
add criteo variants to workload registry
priyakasimbeg Nov 11, 2023
f28ecbc
workload registry
priyakasimbeg Nov 11, 2023
3e44f69
add regression tests
priyakasimbeg Nov 11, 2023
a3a827c
add empty line
priyakasimbeg Nov 11, 2023
e1e9b4e
formatting
priyakasimbeg Nov 11, 2023
d3334f5
regression test fix
priyakasimbeg Nov 13, 2023
b6443f2
add criteo_variants to valid workloads to startup.sh"
priyakasimbeg Nov 13, 2023
4805423
syntax fix
priyakasimbeg Nov 14, 2023
b5c1979
add bsz for criteo1tb variants
priyakasimbeg Nov 14, 2023
87a0425
change regression test name for variants
priyakasimbeg Nov 14, 2023
6e6e786
syntax fix
priyakasimbeg Nov 14, 2023
dc208d0
lint fix
priyakasimbeg Nov 14, 2023
2ee0f77
formatting
priyakasimbeg Nov 14, 2023
fa75663
fix
priyakasimbeg Nov 14, 2023
d21f05b
fix typo
priyakasimbeg Nov 14, 2023
8776748
fix resnet
priyakasimbeg Nov 15, 2023
6c8fbe9
modify regression test
priyakasimbeg Nov 15, 2023
7b7e750
modify_test
priyakasimbeg Nov 15, 2023
e601db4
add tests
priyakasimbeg Nov 16, 2023
7a71040
Merge branch 'dev' into criteo_workload_variants
priyakasimbeg Nov 16, 2023
8d2e828
regression tests
priyakasimbeg Nov 16, 2023
2193830
remove variant bsz
priyakasimbeg Nov 16, 2023
d46d5e0
add helper fn for get_baseworkload_name
priyakasimbeg Nov 16, 2023
7ddffd3
fix
priyakasimbeg Nov 16, 2023
6b2ed1f
modify conformer_resnet model
priyakasimbeg Nov 16, 2023
cd2d672
ln
priyakasimbeg Nov 16, 2023
f0a369a
fix
priyakasimbeg Nov 16, 2023
01d668c
fix
priyakasimbeg Nov 16, 2023
910e974
fix
priyakasimbeg Nov 16, 2023
063c229
debugging
priyakasimbeg Nov 16, 2023
2bf1988
fix
priyakasimbeg Nov 16, 2023
2c2a7a9
debugging
priyakasimbeg Nov 16, 2023
e698bc7
resnet block
priyakasimbeg Nov 17, 2023
9fd4efa
add resnet block to criteo resnet variant
priyakasimbeg Nov 17, 2023
bc78a19
add back dims
priyakasimbeg Nov 17, 2023
b1d2224
fix
priyakasimbeg Nov 17, 2023
6a46a6c
resnet
priyakasimbeg Nov 17, 2023
83c4ade
comment out pytorch
priyakasimbeg Nov 17, 2023
e60a159
resnet fix
priyakasimbeg Nov 17, 2023
e95fe17
fix
priyakasimbeg Nov 17, 2023
1e5c709
fix
priyakasimbeg Nov 17, 2023
04da955
debugging
priyakasimbeg Nov 17, 2023
0b3c7a9
debugging
priyakasimbeg Nov 17, 2023
76169a2
mlp dims
priyakasimbeg Nov 17, 2023
cf7f221
variant fix
priyakasimbeg Nov 17, 2023
6fe0901
fix dlrm variant
priyakasimbeg Nov 17, 2023
0723d82
dlrm fix
priyakasimbeg Nov 17, 2023
d56269b
debugging
priyakasimbeg Nov 17, 2023
61f052b
debugging
priyakasimbeg Nov 17, 2023
ff5aa91
debug
priyakasimbeg Nov 17, 2023
df48937
debug
priyakasimbeg Nov 17, 2023
8a574e3
debug
priyakasimbeg Nov 17, 2023
82e42e3
debug
priyakasimbeg Nov 17, 2023
f2beed9
debug
priyakasimbeg Nov 17, 2023
17c1190
fix
priyakasimbeg Nov 17, 2023
e5a31f5
debug
priyakasimbeg Nov 17, 2023
7c98352
debugging
priyakasimbeg Nov 17, 2023
c9facb2
debugging
priyakasimbeg Nov 17, 2023
549be97
debugging
priyakasimbeg Nov 17, 2023
6967445
debug
priyakasimbeg Nov 17, 2023
5d237eb
debug
priyakasimbeg Nov 17, 2023
9a132b1
debugging
priyakasimbeg Nov 17, 2023
89dea72
debugging
priyakasimbeg Nov 18, 2023
2c54700
Merge branch 'criteo_workload_variants' of github.com:mlcommons/algor…
priyakasimbeg Nov 18, 2023
fdca6c5
debugging
priyakasimbeg Nov 18, 2023
29835e3
debugging
priyakasimbeg Nov 18, 2023
3a859eb
clarify output of diff test
priyakasimbeg Nov 18, 2023
7829863
key transform
priyakasimbeg Nov 18, 2023
a47e6b1
syntax
priyakasimbeg Nov 18, 2023
f75944d
diff
priyakasimbeg Nov 18, 2023
ef68ba8
logging
priyakasimbeg Nov 18, 2023
e096a3a
debug
priyakasimbeg Nov 18, 2023
0c7a864
debug
priyakasimbeg Nov 18, 2023
271412d
debug
priyakasimbeg Nov 18, 2023
650ef43
debugging
priyakasimbeg Nov 18, 2023
ed1cdba
debug
priyakasimbeg Nov 18, 2023
f59b48d
debug
priyakasimbeg Nov 18, 2023
47e7166
debug
priyakasimbeg Nov 18, 2023
39a6c31
debug
priyakasimbeg Nov 18, 2023
f70dd49
debug
priyakasimbeg Nov 18, 2023
0966015
debug
priyakasimbeg Nov 18, 2023
d7d4638
remove some debugging statements
priyakasimbeg Nov 18, 2023
b1c35d4
add debugging statement
priyakasimbeg Nov 18, 2023
e7e52f0
resnet fix
priyakasimbeg Nov 20, 2023
55d72d8
fix
priyakasimbeg Nov 20, 2023
ca8a00a
debugging
priyakasimbeg Nov 20, 2023
d929d76
compare_fix
priyakasimbeg Nov 20, 2023
5192df7
fix
priyakasimbeg Nov 20, 2023
a88f516
fix
priyakasimbeg Nov 20, 2023
7aca234
fix
priyakasimbeg Nov 20, 2023
c2d288e
block count
priyakasimbeg Nov 21, 2023
1bb484f
fix
priyakasimbeg Nov 21, 2023
b50e9dd
fix resnet jax
priyakasimbeg Nov 21, 2023
c83bdad
remove debugging statements
priyakasimbeg Nov 21, 2023
a6f2ba0
fix logging
priyakasimbeg Nov 21, 2023
b641ab9
add back print statements
priyakasimbeg Nov 21, 2023
62a43c8
resnet fix
priyakasimbeg Nov 21, 2023
84463e2
change block structures
priyakasimbeg Nov 21, 2023
6385365
fix
priyakasimbeg Nov 21, 2023
a2b64e9
debugging
priyakasimbeg Nov 22, 2023
ce7fcd9
debug
priyakasimbeg Nov 22, 2023
26f7864
fix
priyakasimbeg Nov 22, 2023
8243820
debug
priyakasimbeg Nov 22, 2023
c44b5bd
debug
priyakasimbeg Nov 22, 2023
6913e92
fix
priyakasimbeg Nov 22, 2023
6fc1830
debugging
priyakasimbeg Nov 22, 2023
a140fc0
fix
priyakasimbeg Nov 22, 2023
9550c15
debug
priyakasimbeg Nov 22, 2023
bcf68ed
fix test
priyakasimbeg Nov 22, 2023
85be241
remove debugging statements
priyakasimbeg Nov 22, 2023
987686c
add jax model summary helper fn
priyakasimbeg Nov 22, 2023
c6e26da
formatting
priyakasimbeg Nov 22, 2023
90e8c80
update jax criteo workload
priyakasimbeg Nov 22, 2023
e440208
criteo workload variant
priyakasimbeg Nov 22, 2023
626420c
fixes
priyakasimbeg Nov 22, 2023
e070fe3
debug
priyakasimbeg Nov 22, 2023
412e4fd
debug
priyakasimbeg Nov 22, 2023
c982326
debug
priyakasimbeg Nov 22, 2023
f960a0a
debugging
priyakasimbeg Nov 22, 2023
122fbf4
debugging
priyakasimbeg Nov 22, 2023
bf12d5b
shape debugging
priyakasimbeg Nov 22, 2023
a234a8a
debugging
priyakasimbeg Nov 22, 2023
456573a
debugging
priyakasimbeg Nov 22, 2023
12f5da6
debugging
priyakasimbeg Nov 22, 2023
438deaf
debugging
priyakasimbeg Nov 22, 2023
49b45f6
debug
priyakasimbeg Nov 22, 2023
3ed3633
criteo variants
priyakasimbeg Nov 22, 2023
834ac9c
debugging
priyakasimbeg Nov 22, 2023
e4c7b34
fix
priyakasimbeg Nov 22, 2023
ac6baf7
embedding initialization criteo
priyakasimbeg Nov 22, 2023
e138c83
test embedding init
priyakasimbeg Nov 22, 2023
132223b
fix
priyakasimbeg Nov 22, 2023
071ddf3
add embedding init multiplier
priyakasimbeg Nov 22, 2023
134d0bb
debugging
priyakasimbeg Nov 22, 2023
955b70d
test
priyakasimbeg Nov 22, 2023
dadd4d6
fix
priyakasimbeg Nov 22, 2023
8e832a3
fix
priyakasimbeg Nov 22, 2023
07daeee
debug
priyakasimbeg Nov 22, 2023
263108f
debug
priyakasimbeg Nov 22, 2023
a7dd161
fix
priyakasimbeg Nov 22, 2023
40a4cde
debugging
priyakasimbeg Nov 22, 2023
1740325
debugging
priyakasimbeg Nov 22, 2023
b05f510
fix
priyakasimbeg Nov 22, 2023
3cf57a7
add tests
priyakasimbeg Nov 22, 2023
123dec9
Merge branch 'dev' into criteo_workload_variants
priyakasimbeg Nov 22, 2023
a5568ed
clean up
priyakasimbeg Nov 22, 2023
76d1749
formatting
priyakasimbeg Nov 22, 2023
284c30c
reformat
priyakasimbeg Nov 22, 2023
547e225
sorting imports
priyakasimbeg Nov 22, 2023
a37d3fa
fix
priyakasimbeg Nov 22, 2023
1af2a8a
remove unused imports
priyakasimbeg Nov 22, 2023
72ede87
pylint
priyakasimbeg Nov 22, 2023
f6025a2
add exception
priyakasimbeg Nov 22, 2023
f947e23
add clarifying docs for submission example
priyakasimbeg Nov 22, 2023
765c363
formatting
priyakasimbeg Nov 22, 2023
9da4996
make exception specific
priyakasimbeg Nov 22, 2023
96bcea6
formatting
priyakasimbeg Nov 22, 2023
59384c6
pylint
priyakasimbeg Nov 22, 2023
c07e0e7
fixes
priyakasimbeg Dec 7, 2023
85 changes: 85 additions & 0 deletions .github/workflows/regression_tests_variants.yml
@@ -0,0 +1,85 @@
name: Containerized Regression Tests for Workload Variants

on:
  pull_request:
    branches:
      - 'main'

jobs:
  build_and_push_jax_docker_image:
    runs-on: self-hosted
    steps:
      - uses: actions/checkout@v2
      - name: Build and push docker images
        run: |
          GIT_BRANCH=${{ github.head_ref || github.ref_name }}
          FRAMEWORK=jax
          IMAGE_NAME="algoperf_${FRAMEWORK}_${GIT_BRANCH}"
          cd $HOME/algorithmic-efficiency/docker
          docker build --no-cache -t $IMAGE_NAME . --build-arg framework=$FRAMEWORK --build-arg branch=$GIT_BRANCH
          BUILD_RETURN=$?
          if [[ ${BUILD_RETURN} != 0 ]]; then exit ${BUILD_RETURN}; fi
          docker tag $IMAGE_NAME us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME
          docker push us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME
  build_and_push_pytorch_docker_image:
    runs-on: self-hosted
    steps:
      - uses: actions/checkout@v2
      - name: Build and push docker images
        run: |
          GIT_BRANCH=${{ github.head_ref || github.ref_name }}
          FRAMEWORK=pytorch
          IMAGE_NAME="algoperf_${FRAMEWORK}_${GIT_BRANCH}"
          cd $HOME/algorithmic-efficiency/docker
          docker build --no-cache -t $IMAGE_NAME . --build-arg framework=$FRAMEWORK --build-arg branch=$GIT_BRANCH
          BUILD_RETURN=$?
          if [[ ${BUILD_RETURN} != 0 ]]; then exit ${BUILD_RETURN}; fi
          docker tag $IMAGE_NAME us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME
          docker push us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/$IMAGE_NAME
  criteo_layernorm_jax:
    runs-on: self-hosted
    needs: build_and_push_jax_docker_image
    steps:
      - uses: actions/checkout@v2
      - name: Run containerized workload
        run: |
          docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }}
          docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d criteo1tb -f jax -s baselines/adamw/jax/submission.py -w criteo1tb_layernorm -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false
  criteo_resnet_jax:
    runs-on: self-hosted
    needs: build_and_push_jax_docker_image
    steps:
      - uses: actions/checkout@v2
      - name: Run containerized workload
        run: |
          docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }}
          docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_jax_${{ github.head_ref || github.ref_name }} -d criteo1tb -f jax -s baselines/adamw/jax/submission.py -w criteo1tb_resnet -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false
  criteo_layernorm_pytorch:
    runs-on: self-hosted
    needs: build_and_push_pytorch_docker_image
    steps:
      - uses: actions/checkout@v2
      - name: Run containerized workload
        run: |
          docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }}
          docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d criteo1tb -f pytorch -s baselines/adamw/pytorch/submission.py -w criteo1tb_layernorm -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false
  criteo_resnet_pytorch:
    runs-on: self-hosted
    needs: build_and_push_pytorch_docker_image
    steps:
      - uses: actions/checkout@v2
      - name: Run containerized workload
        run: |
          docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }}
          docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d criteo1tb -f pytorch -s baselines/adamw/pytorch/submission.py -w criteo1tb_resnet -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false
  criteo_embed_init_pytorch:
    runs-on: self-hosted
    needs: build_and_push_pytorch_docker_image
    steps:
      - uses: actions/checkout@v2
      - name: Run containerized workload
        run: |
          docker pull us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }}
          docker run -v $HOME/data/:/data/ -v $HOME/experiment_runs/:/experiment_runs -v $HOME/experiment_runs/logs:/logs --gpus all --ipc=host us-central1-docker.pkg.dev/training-algorithms-external/mlcommons-docker-repo/algoperf_pytorch_${{ github.head_ref || github.ref_name }} -d criteo1tb -f pytorch -s baselines/adamw/pytorch/submission.py -w criteo1tb_embed_init -t baselines/adamw/tuning_search_space.json -e tests/regression_tests/adamw -m 10 -c False -o True -r false


111 changes: 109 additions & 2 deletions algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/models.py
@@ -7,6 +7,101 @@
import jax.numpy as jnp


class DLRMResNet(nn.Module):
  """Define a DLRMResNet model.

  Parameters:
    vocab_size: the size of a single unified embedding table.
    mlp_bottom_dims: dimensions of dense layers of the bottom mlp.
    mlp_top_dims: dimensions of dense layers of the top mlp.
    num_dense_features: number of dense features as the bottom mlp input.
    embed_dim: embedding dimension.
  """

  vocab_size: int = 32 * 128 * 1024  # 4_194_304
  num_dense_features: int = 13
  mlp_bottom_dims: Sequence[int] = (256, 256, 256)
  mlp_top_dims: Sequence[int] = (256, 256, 256, 256, 1)
  embed_dim: int = 128
  dropout_rate: float = 0.0
  use_layer_norm: bool = False  # Unused.
  embedding_init_multiplier: float = None  # Unused.

  @nn.compact
  def __call__(self, x, train):
    bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1)
    cat_features = jnp.asarray(cat_features, dtype=jnp.int32)

    # Bottom MLP.
    mlp_bottom_dims = self.mlp_bottom_dims

    bot_mlp_input = nn.Dense(
        mlp_bottom_dims[0],
        kernel_init=jnn.initializers.glorot_uniform(),
        bias_init=jnn.initializers.normal(stddev=1.0 / mlp_bottom_dims[0]**0.5),
    )(bot_mlp_input)
    bot_mlp_input = nn.relu(bot_mlp_input)

    for dense_dim in mlp_bottom_dims[1:]:
      x = nn.Dense(
          dense_dim,
          kernel_init=jnn.initializers.glorot_uniform(),
          bias_init=jnn.initializers.normal(stddev=1.0 / dense_dim**0.5),
      )(bot_mlp_input)
      bot_mlp_input += nn.relu(x)

    base_init_fn = jnn.initializers.uniform(scale=1.0)
    # Embedding table init and lookup for a single unified table.
    idx_lookup = jnp.reshape(cat_features, [-1]) % self.vocab_size

    def scaled_init(key, shape, dtype=jnp.float_):
      return base_init_fn(key, shape, dtype) / jnp.sqrt(self.vocab_size)

    embedding_table = self.param('embedding_table',
                                 scaled_init,
                                 [self.vocab_size, self.embed_dim])

    embed_features = embedding_table[idx_lookup]
    batch_size = bot_mlp_input.shape[0]
    embed_features = jnp.reshape(embed_features,
                                 (batch_size, 26 * self.embed_dim))
    top_mlp_input = jnp.concatenate([bot_mlp_input, embed_features], axis=1)
    mlp_input_dim = top_mlp_input.shape[1]
    mlp_top_dims = self.mlp_top_dims
    num_layers_top = len(mlp_top_dims)
    top_mlp_input = nn.Dense(
        mlp_top_dims[0],
        kernel_init=jnn.initializers.normal(
            stddev=jnp.sqrt(2.0 / (mlp_input_dim + mlp_top_dims[0]))),
        bias_init=jnn.initializers.normal(
            stddev=jnp.sqrt(1.0 / mlp_top_dims[0])))(top_mlp_input)
    top_mlp_input = nn.relu(top_mlp_input)
    for layer_idx, fan_out in list(enumerate(mlp_top_dims))[1:-1]:
      fan_in = mlp_top_dims[layer_idx - 1]
      x = nn.Dense(
          fan_out,
          kernel_init=jnn.initializers.normal(
              stddev=jnp.sqrt(2.0 / (fan_in + fan_out))),
          bias_init=jnn.initializers.normal(
              stddev=jnp.sqrt(1.0 / mlp_top_dims[layer_idx])))(top_mlp_input)
      x = nn.relu(x)
      if self.dropout_rate > 0.0 and layer_idx == num_layers_top - 2:
        x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x)
      top_mlp_input += x
    # In the DLRM model the last layer width is always 1, so we can hardcode it
    # below.
    logits = nn.Dense(
        1,
        kernel_init=jnn.initializers.normal(
            stddev=jnp.sqrt(2.0 / (mlp_top_dims[-2] + 1))),
        bias_init=jnn.initializers.normal(stddev=jnp.sqrt(1.0)))(top_mlp_input)
    return logits


def dot_interact(concat_features):
  """Performs feature interaction operation between dense or sparse features.
  Input tensors represent dense or sparse features.
@@ -52,6 +147,8 @@ class DlrmSmall(nn.Module):
   mlp_top_dims: Sequence[int] = (1024, 1024, 512, 256, 1)
   embed_dim: int = 128
   dropout_rate: float = 0.0
+  use_layer_norm: bool = False
+  embedding_init_multiplier: float = None

   @nn.compact
   def __call__(self, x, train):
@@ -67,6 +164,8 @@ def __call__(self, x, train):
     )(
         bot_mlp_input)
     bot_mlp_input = nn.relu(bot_mlp_input)
+    if self.use_layer_norm:
+      bot_mlp_input = nn.LayerNorm()(bot_mlp_input)
     bot_mlp_output = bot_mlp_input
     batch_size = bot_mlp_output.shape[0]
     feature_stack = jnp.reshape(bot_mlp_output,
@@ -75,9 +174,13 @@ def __call__(self, x, train):
     # Embedding table look-up.
     idx_lookup = jnp.reshape(cat_features, [-1]) % self.vocab_size

+    if self.embedding_init_multiplier is None:
+      scale = 1 / jnp.sqrt(self.vocab_size)
+    else:
+      scale = self.embedding_init_multiplier
+
     def scaled_init(key, shape, dtype=jnp.float_):
-      return (jnn.initializers.uniform(scale=1.0)(key, shape, dtype) /
-              jnp.sqrt(self.vocab_size))
+      return jnn.initializers.uniform(scale=1.0)(key, shape, dtype) * scale

     embedding_table = self.param('embedding_table',
                                  scaled_init, [self.vocab_size, self.embed_dim])
@@ -86,6 +189,8 @@ def scaled_init(key, shape, dtype=jnp.float_):
     embed_features = embedding_table[idx_lookup]
     embed_features = jnp.reshape(embed_features,
                                  [batch_size, -1, self.embed_dim])
+    if self.use_layer_norm:
+      embed_features = nn.LayerNorm()(embed_features)
     feature_stack = jnp.concatenate([feature_stack, embed_features], axis=1)
     dot_interact_output = dot_interact(concat_features=feature_stack)
     top_mlp_input = jnp.concatenate([bot_mlp_output, dot_interact_output],
@@ -103,6 +208,8 @@ def scaled_init(key, shape, dtype=jnp.float_):
               top_mlp_input)
       if layer_idx < (num_layers_top - 1):
         top_mlp_input = nn.relu(top_mlp_input)
+        if self.use_layer_norm:
+          top_mlp_input = nn.LayerNorm()(top_mlp_input)
       if (self.dropout_rate is not None and self.dropout_rate > 0.0 and
           layer_idx == num_layers_top - 2):
         top_mlp_input = nn.Dropout(
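For reference, a minimal smoke-test sketch for the new DLRMResNet module above. This is not part of the diff; the reduced vocab_size and the all-zeros fake batch are assumptions, mirroring the 13 dense + 26 categorical input layout used by init_model_fn below:

import jax
import jax.numpy as jnp

# Assumes DLRMResNet from models.py above is in scope.
model = DLRMResNet(vocab_size=32 * 128 * 16)  # reduced vocab, cheap to init
fake_batch = jnp.zeros((2, 13 + 26))  # 13 dense features, then 26 categorical ids
variables = model.init(
    {'params': jax.random.PRNGKey(0), 'dropout': jax.random.PRNGKey(1)},
    fake_batch,
    train=False)
logits = model.apply(variables, fake_batch, train=False)
assert logits.shape == (2, 1)  # the final top-MLP layer width is hardcoded to 1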
algorithmic_efficiency/workloads/criteo1tb/criteo1tb_jax/workload.py
@@ -73,23 +73,31 @@ def init_model_fn(
       self,
       rng: spec.RandomState,
       dropout_rate: Optional[float] = None,
-      aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
+      aux_dropout_rate: Optional[float] = None,
+      tabulate: Optional[bool] = False,
+  ) -> spec.ModelInitState:
     """Only dropout is used."""
     del aux_dropout_rate
-    self._model = models.DlrmSmall(
+    if self.use_resnet:
+      model_class = models.DLRMResNet
+    else:
+      model_class = models.DlrmSmall
+    self._model = model_class(
         vocab_size=self.vocab_size,
         num_dense_features=self.num_dense_features,
         mlp_bottom_dims=self.mlp_bottom_dims,
         mlp_top_dims=self.mlp_top_dims,
         embed_dim=self.embed_dim,
-        dropout_rate=dropout_rate)
+        dropout_rate=dropout_rate,
+        use_layer_norm=self.use_layer_norm,
+        embedding_init_multiplier=self.embedding_init_multiplier)

     params_rng, dropout_rng = jax.random.split(rng)
     init_fake_batch_size = 2
     num_categorical_features = 26
-    input_size = self.num_dense_features + num_categorical_features
+    num_dense_features = 13
+    input_size = num_dense_features + num_categorical_features
     input_shape = (init_fake_batch_size, input_size)

     init_fn = functools.partial(self._model.init, train=False)
     initial_variables = jax.jit(init_fn)(
         {'params': params_rng, 'dropout': dropout_rng},
@@ -154,3 +162,53 @@ def _eval_batch(self,

class Criteo1TbDlrmSmallTestWorkload(Criteo1TbDlrmSmallWorkload):
  vocab_size: int = 32 * 128 * 16


class Criteo1TbDlrmSmallLayerNormWorkload(Criteo1TbDlrmSmallWorkload):

  @property
  def use_layer_norm(self) -> bool:
    """Whether or not to use LayerNorm in the model."""
    return True

  @property
  def validation_target_value(self) -> float:
    return 0.123744

  @property
  def test_target_value(self) -> float:
    return 0.126152


class Criteo1TbDlrmSmallResNetWorkload(Criteo1TbDlrmSmallWorkload):
  mlp_bottom_dims: Tuple[int, int, int] = (256, 256, 256)
  mlp_top_dims: Tuple[int, int, int, int, int] = (256, 256, 256, 256, 1)

  @property
  def use_resnet(self) -> bool:
    """Whether or not to use residual connections in the model."""
    return True

  @property
  def validation_target_value(self) -> float:
    return 0.124027

  @property
  def test_target_value(self) -> float:
    return 0.126468


class Criteo1TbDlrmSmallEmbedInitWorkload(Criteo1TbDlrmSmallWorkload):

  @property
  def validation_target_value(self) -> float:
    return 0.124286

  @property
  def test_target_value(self) -> float:
    return 0.126725

  @property
  def embedding_init_multiplier(self) -> float:
    return 1.0
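A minimal usage sketch for the variant workloads above. This is not part of the diff; it assumes the surrounding AlgoPerf spec API, in which init_model_fn returns a (params, model_state) pair, and an assumed base-class default of False for use_layer_norm and use_resnet:

import jax

# Variants only override properties; init_model_fn above reads them to pick
# the model class and its flags.
workload = Criteo1TbDlrmSmallLayerNormWorkload()
params, model_state = workload.init_model_fn(jax.random.PRNGKey(0))

# The embed-init variant changes only the embedding init scale: uniform
# samples are multiplied by its embedding_init_multiplier of 1.0 instead of
# 1 / sqrt(vocab_size).
embed_workload = Criteo1TbDlrmSmallEmbedInitWorkload()
assert embed_workload.embedding_init_multiplier == 1.0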