Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add extra_repr to Linear classes for debugging purpose (#6954)
**Summary** This PR adds `extra_repr` method to some Linear classes so that additional info is printed when printing such modules. It is useful for debugging. Affected modules: - LinearLayer - LinearAllreduce - LmHeadLinearAllreduce The `extra_repr` method gives the following info: - in_features - out_features - bias (true or false) - dtype **Example** Print llama-2-7b model on rank 0 after `init_inference` with world size = 2. Previously we only got class names of these modules: ``` InferenceEngine( (module): LlamaForCausalLM( (model): LlamaModel( (embed_tokens): Embedding(32000, 4096) (layers): ModuleList( (0-31): 32 x LlamaDecoderLayer( (self_attn): LlamaSdpaAttention( (q_proj): LinearLayer() (k_proj): LinearLayer() (v_proj): LinearLayer() (o_proj): LinearAllreduce() (rotary_emb): LlamaRotaryEmbedding() ) (mlp): LlamaMLP( (gate_proj): LinearLayer() (up_proj): LinearLayer() (down_proj): LinearAllreduce() (act_fn): SiLU() ) (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05) (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05) ) ) (norm): LlamaRMSNorm((4096,), eps=1e-05) (rotary_emb): LlamaRotaryEmbedding() ) (lm_head): LmHeadLinearAllreduce() ) ) ``` Now we get more useful info: ``` InferenceEngine( (module): LlamaForCausalLM( (model): LlamaModel( (embed_tokens): Embedding(32000, 4096) (layers): ModuleList( (0-31): 32 x LlamaDecoderLayer( (self_attn): LlamaSdpaAttention( (q_proj): LinearLayer(in_features=4096, out_features=2048, bias=False, dtype=torch.bfloat16) (k_proj): LinearLayer(in_features=4096, out_features=2048, bias=False, dtype=torch.bfloat16) (v_proj): LinearLayer(in_features=4096, out_features=2048, bias=False, dtype=torch.bfloat16) (o_proj): LinearAllreduce(in_features=2048, out_features=4096, bias=False, dtype=torch.bfloat16) (rotary_emb): LlamaRotaryEmbedding() ) (mlp): LlamaMLP( (gate_proj): LinearLayer(in_features=4096, out_features=5504, bias=False, dtype=torch.bfloat16) (up_proj): LinearLayer(in_features=4096, out_features=5504, bias=False, dtype=torch.bfloat16) (down_proj): LinearAllreduce(in_features=5504, out_features=4096, bias=False, dtype=torch.bfloat16) (act_fn): SiLU() ) (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05) (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05) ) ) (norm): LlamaRMSNorm((4096,), eps=1e-05) (rotary_emb): LlamaRotaryEmbedding() ) (lm_head): LmHeadLinearAllreduce(in_features=2048, out_features=32000, bias=False, dtype=torch.bfloat16) ) ) ```
- Loading branch information