forked from mlcommons/algorithmic-efficiency
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_num_params.py
165 lines (154 loc) · 6.2 KB
/
test_num_params.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import jax
import jax.numpy as jnp
import jax.random as jax_rng
import jraph
import pytest
import torch
from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_jax.models import \
DlrmSmall as JaxDlrmSmall
from algorithmic_efficiency.workloads.criteo1tb.criteo1tb_pytorch.models import \
DlrmSmall as PyTorchDlrmSmall
from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.models import \
ResNet18 as JaxResNet_c10
from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_jax.models import \
ResNet50 as JaxResNet
from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_pytorch.models import \
resnet18 as PyTorchResNet_c10
from algorithmic_efficiency.workloads.imagenet_resnet.imagenet_pytorch.models import \
resnet50 as PyTorchResNet
from algorithmic_efficiency.workloads.imagenet_vit.imagenet_jax.models import \
ViT as JaxViT
from algorithmic_efficiency.workloads.imagenet_vit.imagenet_pytorch.models import \
ViT as PyTorchViT
from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax.models import \
Conformer as JaxConformer
from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_jax.models import \
ConformerConfig as JaxConformerConfig
from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.models import \
ConformerConfig as PytorchConformerConfig
from algorithmic_efficiency.workloads.librispeech_conformer.librispeech_pytorch.models import \
ConformerEncoderDecoder as PytorchConformer
from algorithmic_efficiency.workloads.mnist.mnist_jax.workload import \
_Model as JaxMLP
from algorithmic_efficiency.workloads.mnist.mnist_pytorch.workload import \
_Model as PyTorchMLP
from algorithmic_efficiency.workloads.ogbg.ogbg_jax.models import GNN as JaxGNN
from algorithmic_efficiency.workloads.ogbg.ogbg_pytorch.models import \
GNN as PyTorchGNN
from algorithmic_efficiency.workloads.wmt.wmt_jax.models import \
Transformer as JaxTransformer
from algorithmic_efficiency.workloads.wmt.wmt_jax.models import \
TransformerConfig
from algorithmic_efficiency.workloads.wmt.wmt_pytorch.models import \
Transformer as PyTorchTransformer
# Workload names accepted by get_models(); drives the parametrized test below.
WORKLOADS = [
    'mnist',
    'cifar',
    'criteo1tb',
    'imagenet_resnet',
    'imagenet_vit',
    'wmt',
    'ogbg',
    'librispeech_conformer',
]
@pytest.mark.parametrize('workload', WORKLOADS)
def test_matching_num_params(workload):
  """The Jax and PyTorch implementations must have equal parameter counts."""
  jax_params, torch_net = get_models(workload)
  # Total element count over every leaf of the Jax parameter pytree.
  jax_count = sum(leaf.size for leaf in jax.tree_util.tree_leaves(jax_params))
  # Total element count over all trainable PyTorch parameters.
  torch_count = 0
  for param in torch_net.parameters():
    if param.requires_grad:
      torch_count += param.numel()
  assert jax_count == torch_count
def get_models(workload):
  """Construct a matching (Jax, PyTorch) model pair for `workload`.

  Args:
    workload: One of the names in WORKLOADS. Note 'cifar' reuses the
      ImageNet ResNet18 implementations with num_classes=10.

  Returns:
    A tuple (jax_model, pytorch_model): the Jax entry is the initialized
    parameter pytree (the 'params' collection) and the PyTorch entry is the
    constructed torch.nn.Module.

  Raises:
    ValueError: If no model pair is registered for `workload`.
  """
  # Fixed RNG keys so Jax initialization is deterministic across runs.
  init_rngs = {'params': jax_rng.PRNGKey(0), 'dropout': jax_rng.PRNGKey(1)}
  if workload == 'mnist':
    # Init Jax model.
    init_val = jnp.ones((1, 28, 28, 1), jnp.float32)
    jax_model = JaxMLP().init(init_rngs, init_val, train=True)['params']
    # Init PyTorch model.
    pytorch_model = PyTorchMLP()
  elif workload == 'cifar':
    # Init Jax model.
    input_shape = (1, 32, 32, 3)
    model_init = jax.jit(JaxResNet_c10(num_classes=10, dtype=jnp.float32).init)
    jax_model = model_init(init_rngs, jnp.ones(input_shape,
                                               jnp.float32))['params']
    # Init PyTorch model.
    pytorch_model = PyTorchResNet_c10(num_classes=10)
  elif workload == 'criteo1tb':
    # Init Jax model. Both sides must share these hyperparameters for the
    # parameter counts to match.
    mlp_bottom_dims = (512, 256, 128)
    mlp_top_dims = (1024, 1024, 512, 256, 1)
    embed_dim = 128
    vocab_size = 32 * 128 * 1024
    input_shape = (1, 39)
    model_init = JaxDlrmSmall(
        vocab_size=vocab_size,
        num_dense_features=13,
        mlp_bottom_dims=mlp_bottom_dims,
        mlp_top_dims=mlp_top_dims,
        embed_dim=embed_dim).init
    jax_model = model_init(init_rngs, jnp.ones(input_shape, jnp.float32),
                           False)['params']
    # Init PyTorch model.
    pytorch_model = PyTorchDlrmSmall(
        vocab_size=vocab_size,
        num_dense_features=13,
        mlp_bottom_dims=mlp_bottom_dims,
        mlp_top_dims=mlp_top_dims,
        embed_dim=embed_dim)
  elif workload == 'imagenet_resnet':
    # Init Jax model.
    input_shape = (1, 224, 224, 3)
    jax_model = JaxResNet(
        num_classes=1000,
        dtype=jnp.float32).init(init_rngs, jnp.ones(input_shape,
                                                    jnp.float32))['params']
    # Init PyTorch model.
    pytorch_model = PyTorchResNet()
  elif workload == 'imagenet_vit':
    # Init Jax model.
    input_shape = (1, 224, 224, 3)
    jax_model = JaxViT(num_classes=1000).init(
        init_rngs, jnp.ones(input_shape, jnp.float32))['params']
    # Init PyTorch model.
    pytorch_model = PyTorchViT()
  elif workload == 'librispeech_conformer':
    jax_model = JaxConformer(JaxConformerConfig())
    pytorch_model = PytorchConformer(PytorchConformerConfig())
    # Init Jax model: replace the module with its initialized 'params' pytree.
    input_shape = [(320000,), (320000,)]
    fake_input_batch = [jnp.zeros((2, *x), jnp.float32) for x in input_shape]
    jax_model = jax_model.init(
        init_rngs, train=False, *fake_input_batch)["params"]
    # Run the PyTorch model once so lazily-created layers materialize their
    # parameters before they are counted.
    wave = torch.randn(2, 320000)
    pad = torch.zeros_like(wave)
    pytorch_model(wave, pad)
  elif workload == 'wmt':
    # Init Jax model.
    input_shape = (16, 256)
    target_shape = (16, 256)
    # Pass a TransformerConfig *instance*; the original passed the class
    # itself, which only worked because dataclass defaults are readable as
    # class attributes. An instance is the intended flax usage.
    jax_model = JaxTransformer(TransformerConfig()).init(
        init_rngs,
        jnp.ones(input_shape, jnp.float32),
        jnp.ones(target_shape, jnp.float32))['params']
    # Init PyTorch model.
    pytorch_model = PyTorchTransformer()
  elif workload == 'ogbg':
    # Init Jax model with a minimal single-node, single-edge graph.
    fake_batch = jraph.GraphsTuple(
        n_node=jnp.asarray([1]),
        n_edge=jnp.asarray([1]),
        nodes=jnp.ones((1, 9)),
        edges=jnp.ones((1, 3)),
        globals=jnp.zeros((1, 128)),
        senders=jnp.asarray([0]),
        receivers=jnp.asarray([0]))
    jax_model = JaxGNN(num_outputs=128).init(
        init_rngs, fake_batch, train=False)['params']
    # Init PyTorch model.
    pytorch_model = PyTorchGNN(num_outputs=128)
  else:
    raise ValueError(f'Models for workload {workload} are not available.')
  return jax_model, pytorch_model