Add llama 3 tokenizer
sychen52 committed Nov 19, 2024
1 parent 420ed7a commit 8dfef28
Showing 104 changed files with 7,411 additions and 232 deletions.
@@ -267,7 +267,7 @@ model.decoder.transformer.num_layers: 16
model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex'
model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*'
model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat'
-model.decoder.vocab_size: 128256
+model.decoder.vocab_size: 131072
model.dtype: 'jax.numpy.float32'
model.klass: 'axlearn.common.causal_lm.Model'
model.param_init.init_by_param_name['.*weight$'].distribution: 'normal'
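Note on the vocab_size change above: 128256 is the Llama 3 tokenizer's vocabulary (128,000 BPE tokens plus 256 reserved special tokens), and 131072 is 2^17, the next power of two. Padding the model's embedding/output vocabulary past the tokenizer size is presumably done so the token-embedding table divides evenly across shards; that motivation is an inference, not stated in the diff. A minimal sketch of such rounding (the helper is hypothetical, not an axlearn API):

```python
def pad_vocab_size(num_tokens: int) -> int:
    """Hypothetical helper: round a tokenizer vocab up to the next power of two."""
    padded = 1
    while padded < num_tokens:
        padded *= 2
    return padded

assert pad_vocab_size(128_256) == 131_072  # Llama 3 tokenizer -> padded model vocab
```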
@@ -1,4 +1,4 @@
-decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0)
decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=())
decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=())

Large diffs are not rendered by default.

@@ -0,0 +1,9 @@
decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0)
decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=())
decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=())
decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0)
decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 2048), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
decoder/output_norm/scale: constant(1.0)
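Note on the attention shapes in this golden file: the fused qkv_proj weight is (2048, 48, 64) and o_proj is (2048, 32, 64), which is consistent with grouped-query attention using 32 query heads and 8 key/value heads of dimension 64 (48 = 32 + 8 + 8, and 32 x 64 = 2048, the model dimension); the exact head split is inferred from the shapes, not stated in the diff. The normal(0, 1.0 / fan_in) entries follow the usual fan-scaled initialization, with FanAxes recording which axes count toward fan_in and fan_out. A small sanity-check sketch of that decomposition (plain Python, not axlearn code):

```python
# Assumed head split, inferred from the golden shapes above.
model_dim, head_dim = 2048, 64
num_query_heads = 32                      # o_proj weight: (2048, 32, 64)
num_kv_heads = 8                          # 48 - 32 = 16 fused heads = 8 key + 8 value
fused_heads = num_query_heads + 2 * num_kv_heads
assert fused_heads == 48                  # qkv_proj weight: (2048, 48, 64)
assert num_query_heads * head_dim == model_dim
```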
@@ -0,0 +1,10 @@
====================weight_decay_scale root.optimizer====================
decoder/emb/token_emb/weight: 1
decoder/output_norm/scale: 1
decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1
decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1
decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1
decoder/transformer/repeat/layer/feed_forward/norm/scale: 1
decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1
decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1
decoder/transformer/repeat/layer/self_attention/norm/scale: 1
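This golden file records the per-parameter weight-decay scale resolved from the optimizer config; a scale of 1 on every leaf (embedding, projections, and norm scales alike) means weight decay is applied at full strength to every parameter. A generic sketch of how such a scale tree is typically consumed (illustrative JAX, not axlearn's optimizer implementation):

```python
import jax

# Add weight_decay * scale * param to each gradient leaf; a scale of 1 everywhere,
# as in the golden file above, decays every parameter equally.
def add_scaled_weight_decay(grads, params, scales, weight_decay=1e-4):
    return jax.tree_util.tree_map(
        lambda g, p, s: g + weight_decay * s * p, grads, params, scales
    )
```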

Large diffs are not rendered by default.

@@ -0,0 +1,9 @@
decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0)
decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=())
decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=())
decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0)
decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 2048), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
decoder/output_norm/scale: constant(1.0)
@@ -0,0 +1,10 @@
====================weight_decay_scale root.optimizer====================
decoder/emb/token_emb/weight: 1
decoder/output_norm/scale: 1
decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1
decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1
decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1
decoder/transformer/repeat/layer/feed_forward/norm/scale: 1
decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1
decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1
decoder/transformer/repeat/layer/self_attention/norm/scale: 1
@@ -267,7 +267,7 @@ model.decoder.transformer.num_layers: 16
model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex'
model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*'
model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat'
-model.decoder.vocab_size: 128256
+model.decoder.vocab_size: 131072
model.dtype: 'jax.numpy.float32'
model.klass: 'axlearn.common.causal_lm.Model'
model.param_init.init_by_param_name['.*weight$'].distribution: 'normal'
@@ -1,4 +1,4 @@
-decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0)
decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=())
decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=())
@@ -232,7 +232,7 @@ model.decoder.transformer.num_layers: 16
model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex'
model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*'
model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat'
-model.decoder.vocab_size: 128256
+model.decoder.vocab_size: 131072
model.dtype: 'jax.numpy.float32'
model.klass: 'axlearn.common.causal_lm.Model'
model.param_init.init_by_param_name['.*weight$'].distribution: 'normal'
@@ -1,4 +1,4 @@
-decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0)
decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=())
decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=())
