Add extra_repr to Linear classes for debugging purpose (#6954)
**Summary**
This PR adds an `extra_repr` method to some Linear classes so that
additional information is shown when such modules are printed. This is
useful for debugging.
Affected modules:
- LinearLayer
- LinearAllreduce
- LmHeadLinearAllreduce

The `extra_repr` method reports the following information (a minimal
illustration of the mechanism follows this list):
- in_features
- out_features
- bias (True or False)
- dtype
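
For context, `extra_repr` is the standard PyTorch hook for customizing the
string that appears inside a module's printed representation. The sketch below
uses a hypothetical `TinyLinear` module (not the DeepSpeed code) to show how
the returned string ends up in the output of `print(module)`:

```python
import torch
from torch import nn


class TinyLinear(nn.Module):
    """Hypothetical module illustrating how extra_repr feeds into print()."""

    def __init__(self, in_features, out_features, dtype=torch.bfloat16):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype))
        self.bias = None

    def extra_repr(self):
        # Whatever string is returned here is rendered inside "TinyLinear(...)".
        out_features, in_features = self.weight.shape
        return "in_features={}, out_features={}, bias={}, dtype={}".format(
            in_features, out_features, self.bias is not None, self.weight.dtype)


print(TinyLinear(4096, 2048))
# TinyLinear(in_features=4096, out_features=2048, bias=False, dtype=torch.bfloat16)
```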

**Example**
Print the llama-2-7b model on rank 0 after `init_inference` with world
size = 2.
Previously, we only got the class names of these modules:
```
InferenceEngine(
  (module): LlamaForCausalLM(
    (model): LlamaModel(
      (embed_tokens): Embedding(32000, 4096)
      (layers): ModuleList(
        (0-31): 32 x LlamaDecoderLayer(
          (self_attn): LlamaSdpaAttention(
            (q_proj): LinearLayer()
            (k_proj): LinearLayer()
            (v_proj): LinearLayer()
            (o_proj): LinearAllreduce()
            (rotary_emb): LlamaRotaryEmbedding()
          )
          (mlp): LlamaMLP(
            (gate_proj): LinearLayer()
            (up_proj): LinearLayer()
            (down_proj): LinearAllreduce()
            (act_fn): SiLU()
          )
          (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
          (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        )
      )
      (norm): LlamaRMSNorm((4096,), eps=1e-05)
      (rotary_emb): LlamaRotaryEmbedding()
    )
    (lm_head): LmHeadLinearAllreduce()
  )
)
```
Now we get more useful info. Note that the reported sizes are the per-rank shards (e.g. `out_features=2048` for `q_proj` with world size = 2):
```
InferenceEngine(
  (module): LlamaForCausalLM(
    (model): LlamaModel(
      (embed_tokens): Embedding(32000, 4096)
      (layers): ModuleList(
        (0-31): 32 x LlamaDecoderLayer(
          (self_attn): LlamaSdpaAttention(
            (q_proj): LinearLayer(in_features=4096, out_features=2048, bias=False, dtype=torch.bfloat16)
            (k_proj): LinearLayer(in_features=4096, out_features=2048, bias=False, dtype=torch.bfloat16)
            (v_proj): LinearLayer(in_features=4096, out_features=2048, bias=False, dtype=torch.bfloat16)
            (o_proj): LinearAllreduce(in_features=2048, out_features=4096, bias=False, dtype=torch.bfloat16)
            (rotary_emb): LlamaRotaryEmbedding()
          )
          (mlp): LlamaMLP(
            (gate_proj): LinearLayer(in_features=4096, out_features=5504, bias=False, dtype=torch.bfloat16)
            (up_proj): LinearLayer(in_features=4096, out_features=5504, bias=False, dtype=torch.bfloat16)
            (down_proj): LinearAllreduce(in_features=5504, out_features=4096, bias=False, dtype=torch.bfloat16)
            (act_fn): SiLU()
          )
          (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
          (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        )
      )
      (norm): LlamaRMSNorm((4096,), eps=1e-05)
      (rotary_emb): LlamaRotaryEmbedding()
    )
    (lm_head): LmHeadLinearAllreduce(in_features=2048, out_features=32000, bias=False, dtype=torch.bfloat16)
  )
)
```
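
For reference, a printout like the one above can be produced roughly as
follows. This is a minimal sketch, assuming the Hugging Face `transformers`
loader and a `deepspeed` launcher with 2 GPUs; the exact `init_inference`
arguments (e.g. `tensor_parallel`) may differ between DeepSpeed versions.

```python
# Launch with e.g. `deepspeed --num_gpus 2 print_model.py` so that world size = 2.
import deepspeed
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf",
                                             torch_dtype=torch.bfloat16)

# Shard the model across the 2 ranks with tensor parallelism.
engine = deepspeed.init_inference(model,
                                  tensor_parallel={"tp_size": 2},
                                  dtype=torch.bfloat16)

# Print on rank 0 only; the sharded Linear modules now report their
# in_features/out_features/bias/dtype via extra_repr.
if torch.distributed.get_rank() == 0:
    print(engine)
```
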
Xia-Weiwen authored Jan 16, 2025
1 parent 05eaf3d · commit 018ece5
Showing 1 changed file with 21 additions and 0 deletions.

deepspeed/module_inject/layers.py (+21, −0)
```diff
@@ -91,6 +91,13 @@ def forward(self, input):
             output += self.bias
         return output
 
+    def extra_repr(self):
+        out_features, in_features = self.weight.shape if self.weight is not None else (None, None)
+        dtype = self.weight.dtype if self.weight is not None else None
+        extra_repr_str = "in_features={}, out_features={}, bias={}, dtype={}".format(
+            in_features, out_features, self.bias is not None, dtype)
+        return extra_repr_str
+
 
 class LmHeadLinearAllreduce(nn.Module):
 
@@ -120,6 +127,13 @@ def forward(self, input):
             output += self.bias
         return output
 
+    def extra_repr(self):
+        out_features, in_features = self.weight.shape if self.weight is not None else (None, None)
+        dtype = self.weight.dtype if self.weight is not None else None
+        extra_repr_str = "in_features={}, out_features={}, bias={}, dtype={}".format(
+            in_features, out_features, self.bias is not None, dtype)
+        return extra_repr_str
+
 
 class LinearLayer(nn.Module):
 
@@ -144,6 +158,13 @@ def forward(self, input):
             output += self.bias
         return output
 
+    def extra_repr(self):
+        out_features, in_features = self.weight.shape
+        dtype = self.weight.dtype
+        extra_repr_str = "in_features={}, out_features={}, bias={}, dtype={}".format(
+            in_features, out_features, self.bias is not None, dtype)
+        return extra_repr_str
+
 
 class Normalize(nn.Module):
```
