
Commit e42b6f7
Add llama 3 tokenizer
sychen52 committed Nov 19, 2024
1 parent 420ed7a commit e42b6f7
Showing 107 changed files with 7,445 additions and 266 deletions.
@@ -137,7 +137,7 @@ model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.
model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer'
model.decoder.emb.token_emb.param_partition_spec[0]: None
model.decoder.emb.token_emb.param_partition_spec[1]: 'model'
-model.decoder.eos_token_id: 1
+model.decoder.eos_token_id: 128001
model.decoder.klass: 'axlearn.common.decoder.Decoder'
model.decoder.logits_partition_spec[0][0]: 'data'
model.decoder.logits_partition_spec[0][1]: 'expert'
@@ -148,7 +148,7 @@ model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout'
model.decoder.output_norm.eps: 1e-05
model.decoder.output_norm.forward_dtype: None
model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm'
-model.decoder.pad_token_id: 0
+model.decoder.pad_token_id: 128004
model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer'
model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu'
model.decoder.transformer.layer.feed_forward.activation[1]: 'linear'
@@ -267,7 +267,7 @@ model.decoder.transformer.num_layers: 16
model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex'
model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*'
model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat'
-model.decoder.vocab_size: 128256
+model.decoder.vocab_size: 131072
model.dtype: 'jax.numpy.float32'
model.klass: 'axlearn.common.causal_lm.Model'
model.param_init.init_by_param_name['.*weight$'].distribution: 'normal'
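
The recurring change across these golden configs is the switch to Llama 3 token ids: eos_token_id 1 → 128001, pad_token_id 0 → 128004, and vocab_size 128256 → 131072. A minimal sanity-check sketch follows; the Hugging Face repo id and the special-token names in it are assumptions rather than part of this commit, and reading 131072 as a padded vocabulary is likewise an inference.

# A minimal sanity-check sketch, not part of this commit. It assumes the gated
# Hugging Face checkpoint "meta-llama/Llama-3.2-1B" and the Llama 3 special-token
# names below; both are assumptions, chosen only to cross-check the ids that the
# updated configs hard-code.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")

print(tok.convert_tokens_to_ids("<|end_of_text|>"))           # expected 128001 -> eos_token_id
print(tok.convert_tokens_to_ids("<|finetune_right_pad_id|>"))  # expected 128004 -> pad_token_id
print(len(tok.get_vocab()))                                    # expected 128256 raw entries

# The configs set vocab_size to 131072 = 2**17, i.e. the embedding table appears
# to be padded past the raw 128256 entries, presumably so it divides evenly when
# sharded along the 'model' axis.
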
@@ -1,4 +1,4 @@
-decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0)
decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=())
decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=())
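
Beyond the tokenizer ids, the init spec pins down the decoder geometry: a [131072, 2048] embedding, a fused qkv_proj of shape (2048, 48, 64), an o_proj of (2048, 32, 64), and 2048 → 8192 feed-forward projections, with num_layers: 16 visible in the config hunk above. Below is a rough parameter-count sketch derived from those shapes; the split into 32 query and 2 × 8 key/value heads is an inference from the fused 48-head projection, not something stated in the diff.

# Rough parameter count from the shapes in the init spec above. The 32q/8kv head
# split is inferred from qkv_proj's 48 heads of dim 64; it is not stated in the
# commit itself, and small norm-scale vectors are ignored.
hidden, num_layers, head_dim = 2048, 16, 64
q_heads, kv_heads = 32, 8            # 48 = 32 + 8 + 8 fused q/k/v heads
ffn, vocab_padded = 8192, 131072

qkv_proj = hidden * (q_heads + 2 * kv_heads) * head_dim   # (2048, 48, 64)
o_proj = hidden * q_heads * head_dim                       # (2048, 32, 64)
feed_forward = 2 * hidden * ffn + ffn * hidden             # linear1_0, linear1_1, linear2
embedding = vocab_padded * hidden                           # [131072, 2048]

total = num_layers * (qkv_proj + o_proj + feed_forward) + embedding
print(f"{total:,}")  # ~1.24e9 weights, i.e. a 1B-class decoder
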

Large diffs are not rendered by default.

@@ -0,0 +1,9 @@
decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0)
decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=())
decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=())
decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0)
decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 2048), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
decoder/output_norm/scale: constant(1.0)
@@ -0,0 +1,10 @@
====================weight_decay_scale root.optimizer====================
decoder/emb/token_emb/weight: 1
decoder/output_norm/scale: 1
decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1
decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1
decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1
decoder/transformer/repeat/layer/feed_forward/norm/scale: 1
decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1
decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1
decoder/transformer/repeat/layer/self_attention/norm/scale: 1

Large diffs are not rendered by default.

@@ -0,0 +1,9 @@
decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0)
decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=())
decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=())
decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0)
decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(2048, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(8192, 2048), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
decoder/output_norm/scale: constant(1.0)
@@ -0,0 +1,10 @@
====================weight_decay_scale root.optimizer====================
decoder/emb/token_emb/weight: 1
decoder/output_norm/scale: 1
decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1
decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1
decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1
decoder/transformer/repeat/layer/feed_forward/norm/scale: 1
decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1
decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1
decoder/transformer/repeat/layer/self_attention/norm/scale: 1
@@ -137,7 +137,7 @@ model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.
model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer'
model.decoder.emb.token_emb.param_partition_spec[0]: None
model.decoder.emb.token_emb.param_partition_spec[1]: 'model'
-model.decoder.eos_token_id: 1
+model.decoder.eos_token_id: 128001
model.decoder.klass: 'axlearn.common.decoder.Decoder'
model.decoder.logits_partition_spec[0][0]: 'data'
model.decoder.logits_partition_spec[0][1]: 'expert'
@@ -148,7 +148,7 @@ model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout'
model.decoder.output_norm.eps: 1e-05
model.decoder.output_norm.forward_dtype: None
model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm'
-model.decoder.pad_token_id: 0
+model.decoder.pad_token_id: 128004
model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer'
model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu'
model.decoder.transformer.layer.feed_forward.activation[1]: 'linear'
@@ -267,7 +267,7 @@ model.decoder.transformer.num_layers: 16
model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex'
model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*'
model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat'
-model.decoder.vocab_size: 128256
+model.decoder.vocab_size: 131072
model.dtype: 'jax.numpy.float32'
model.klass: 'axlearn.common.causal_lm.Model'
model.param_init.init_by_param_name['.*weight$'].distribution: 'normal'
@@ -1,4 +1,4 @@
-decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0)
decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=())
decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=())
@@ -137,7 +137,7 @@ model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.
model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer'
model.decoder.emb.token_emb.param_partition_spec[0]: None
model.decoder.emb.token_emb.param_partition_spec[1]: 'model'
-model.decoder.eos_token_id: 1
+model.decoder.eos_token_id: 128001
model.decoder.klass: 'axlearn.common.decoder.Decoder'
model.decoder.logits_partition_spec[0][0]: 'data'
model.decoder.logits_partition_spec[0][1]: 'expert'
@@ -148,7 +148,7 @@ model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout'
model.decoder.output_norm.eps: 1e-05
model.decoder.output_norm.forward_dtype: None
model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm'
-model.decoder.pad_token_id: 0
+model.decoder.pad_token_id: 128004
model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer'
model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu'
model.decoder.transformer.layer.feed_forward.activation[1]: 'linear'
@@ -232,7 +232,7 @@ model.decoder.transformer.num_layers: 16
model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex'
model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*'
model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat'
-model.decoder.vocab_size: 128256
+model.decoder.vocab_size: 131072
model.dtype: 'jax.numpy.float32'
model.klass: 'axlearn.common.causal_lm.Model'
model.param_init.init_by_param_name['.*weight$'].distribution: 'normal'
@@ -1,4 +1,4 @@
-decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 2048], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0)
decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 48, 64), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=())
decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(2048, 32, 64), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=())