diff --git a/build/lib/protein_lm/__init__.py b/build/lib/protein_lm/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/build/lib/protein_lm/dataset/__init__.py b/build/lib/protein_lm/dataset/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/build/lib/protein_lm/evaluation/__init__.py b/build/lib/protein_lm/evaluation/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/build/lib/protein_lm/modeling/__init__.py b/build/lib/protein_lm/modeling/__init__.py
new file mode 100644
index 0000000..173d567
--- /dev/null
+++ b/build/lib/protein_lm/modeling/__init__.py
@@ -0,0 +1 @@
+from .models import *
\ No newline at end of file
diff --git a/protein_lm.yml b/protein_lm.yml
index 5ce09c6..aba06ef 100644
--- a/protein_lm.yml
+++ b/protein_lm.yml
@@ -20,3 +20,4 @@ dependencies:
   - evaluate
   - pytest
   - fair-esm
+  - mup
diff --git a/protein_lm/modeling/__init__.py b/protein_lm/modeling/__init__.py
index e69de29..173d567 100644
--- a/protein_lm/modeling/__init__.py
+++ b/protein_lm/modeling/__init__.py
@@ -0,0 +1 @@
+from .models import *
\ No newline at end of file
diff --git a/protein_lm/modeling/models/__init__.py b/protein_lm/modeling/models/__init__.py
index e69de29..20b24e2 100644
--- a/protein_lm/modeling/models/__init__.py
+++ b/protein_lm/modeling/models/__init__.py
@@ -0,0 +1 @@
+from .apt import *
\ No newline at end of file
diff --git a/protein_lm/modeling/models/apt/__init__.py b/protein_lm/modeling/models/apt/__init__.py
index e69de29..a098708 100644
--- a/protein_lm/modeling/models/apt/__init__.py
+++ b/protein_lm/modeling/models/apt/__init__.py
@@ -0,0 +1,2 @@
+from .config import APTConfig
+from .model_pytorch import *
\ No newline at end of file
diff --git a/protein_lm/modeling/models/apt/config.py b/protein_lm/modeling/models/apt/config.py
index 36f2c04..9d592d2 100644
--- a/protein_lm/modeling/models/apt/config.py
+++ b/protein_lm/modeling/models/apt/config.py
@@ -11,6 +11,20 @@ def __init__(
         position_embedding="learned",
         tokenizer=None,
         max_sequence_length = 1024,
+        query_zero_init = True,
+        n_layer = None,
+        contact_prediction_head = False,
+        initializer_range = 0.02,
+        # whether to use MuParametrization
+        use_mup = False,
+        # whether to initialize the output (readout) layer with zero-initialization
+        readout_zero_init = True,
+        # the output layer multiplier if mup is used, see https://github.com/microsoft/mup/blob/19814971934ef91dd546f88e913fc963e096d11c/mup/layer.py#L56
+        mup_output_mult = 1.0,
+        width_mult_for_weights = 2.0,
+        # rope
+        rope_theta = 0.0,
+        rope_scaling_factor=1,
         **kwargs
     ):
         super().__init__(**kwargs)
@@ -18,4 +32,15 @@ def __init__(
         self.position_embedding = position_embedding
         self.tokenizer = tokenizer
         self.max_sequence_length = max_sequence_length
+
+        self.use_mup = use_mup
+        self.query_zero_init = query_zero_init
+        self.n_layer = n_layer
+        self.contact_prediction_head = contact_prediction_head
+        self.initializer_range = initializer_range
+        self.readout_zero_init = readout_zero_init
+        self.mup_output_mult = mup_output_mult
+        self.width_mult_for_weights = width_mult_for_weights
+        self.rope_theta = rope_theta
+        self.rope_scaling_factor = rope_scaling_factor
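A minimal usage sketch of the new config fields above (not part of the patch; the width and depth values are purely illustrative):

    # Illustrative only: constructing an APTConfig with the muP-related options added above.
    from protein_lm.modeling.models.apt.config import APTConfig

    config = APTConfig(
        n_embd=256,                # hypothetical model width
        n_layer=8,                 # used for the depth-scaled c_proj init in _init_weights
        use_mup=True,              # switch on MuParametrization
        query_zero_init=True,      # zero-initialize the query projection
        readout_zero_init=True,    # zero-initialize the MuReadout output layer
        mup_output_mult=1.0,       # output multiplier forwarded to MuReadout
        initializer_range=0.02,    # std used by the muP-aware _init_weights
    )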
diff --git a/protein_lm/modeling/models/apt/model_pytorch.py b/protein_lm/modeling/models/apt/model_pytorch.py
index f814519..6634b68 100644
--- a/protein_lm/modeling/models/apt/model_pytorch.py
+++ b/protein_lm/modeling/models/apt/model_pytorch.py
@@ -1,4 +1,5 @@
 from typing import Optional, Tuple, Union
+import math
 import torch
 from torch import nn
 from torch.nn import CrossEntropyLoss
@@ -8,6 +9,7 @@
 from transformers.pytorch_utils import Conv1D
 from transformers.activations import ACT2FN
 from transformers.utils import logging
+from mup import MuReadout, MuSharedReadout, normal_
 from protein_lm.modeling.utils.rotary_embedding import RotaryEmbedding
 from protein_lm.modeling.utils.rerope_embedding import RectifiedRotaryEmbedding
 from protein_lm.modeling.utils.alibi_embedding import create_alibi_tensor
@@ -16,6 +18,7 @@
 logger = logging.get_logger(__name__)
 
+
 class APTAttention(GPT2Attention):
     def __init__(self, config, is_cross_attention=False, layer_idx=None):
         super().__init__(config, is_cross_attention=is_cross_attention, layer_idx=layer_idx)
@@ -42,6 +45,13 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
                 f" {self.num_heads})."
             )
 
+        # muP
+        self.use_mup = config.use_mup
+        self.attn_score = nn.Identity() # just for coordcheck
+        self.query = nn.Identity() # just for coordcheck
+        self.key = nn.Identity() # just for coordcheck
+        self.value = nn.Identity() # just for coordcheck
+
         self.scale_attn_weights = config.scale_attn_weights
         self.is_cross_attention = is_cross_attention
 
@@ -55,13 +65,20 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
             self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
         else:
             self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)
+
         self.c_proj = Conv1D(self.embed_dim, self.embed_dim)
 
-        self.attn_dropout = nn.Dropout(config.attn_pdrop)
-        self.resid_dropout = nn.Dropout(config.resid_pdrop)
+        if self.use_mup:
+            self.attn_dropout = nn.Identity()
+            self.resid_dropout = nn.Identity()
+        else:
+            self.attn_dropout = nn.Dropout(config.attn_pdrop)
+            self.resid_dropout = nn.Dropout(config.resid_pdrop)
 
         self.pruned_heads = set()
+
+        self.rot_emb=None
         if self.position_embedding == "rope":
             self.rot_emb=RotaryEmbedding(dim=self.head_dim)
@@ -72,15 +89,23 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
         elif self.position_embedding=="dynamic_rope_scaling":
             self.rot_emb=LlamaDynamicNTKScalingRotaryEmbedding(dim=self.head_dim,max_position_embeddings=self.max_positions,scaling_factor=self.rope_scaling_factor,base=self.rope_theta)
-    
+
     def _attn(self, query, key, value, attention_mask=None, head_mask=None,alibi_bias=None):
         attn_weights = torch.matmul(query, key.transpose(-1, -2))
 
+        #muP
         if self.scale_attn_weights:
-            attn_weights = attn_weights / torch.full(
-                [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
-            )
+            if self.use_mup:
+                attn_weights = attn_weights / torch.full(
+                    [], value.size(-1), dtype=attn_weights.dtype, device=attn_weights.device
+                )
+            else:
+                attn_weights = attn_weights / torch.full(
+                    [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
+                )
+
+        attn_weights = self.attn_score(attn_weights)
 
         # Layer-wise attention scaling
         if self.scale_attn_by_inverse_layer_idx:
@@ -97,7 +122,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None,alibi_bia
             attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
         if alibi_bias is not None:
             attn_weights = attn_weights + alibi_bias[:,:,:attn_weights.size(-1)]
-        
+
         if attention_mask is not None:
             # Apply the attention mask
             attn_weights = attn_weights + attention_mask
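The functional change in `_attn` above is the attention scaling rule: with `use_mup` enabled, the logits are divided by the full head dimension d rather than sqrt(d), which is what keeps the logit scale width-independent under muP. A standalone sketch of just that branch (names and signature are illustrative, not taken from the repo):

    # Illustrative only: the scaling rule applied in APTAttention._attn above.
    import torch

    def scale_attn_logits(attn_weights: torch.Tensor, head_dim: int, use_mup: bool) -> torch.Tensor:
        if use_mup:
            # muP: 1/d scaling
            return attn_weights / head_dim
        # standard parametrization: 1/sqrt(d) scaling
        return attn_weights / head_dim ** 0.5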
@@ -150,7 +175,7 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, hea
 
         if alibi_bias is not None:
             attn_weights = attn_weights + alibi_bias[:,:,:attn_weights.size(-1)]
-        
+
         if attention_mask is not None:
             # Apply the attention mask
             attn_weights = attn_weights + attention_mask
@@ -171,7 +196,7 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, hea
 
         return attn_output, attn_weights
 
-    
+
     def forward(
         self,
         hidden_states: Optional[Tuple[torch.FloatTensor]],
@@ -202,11 +227,15 @@ def forward(
         query = self._split_heads(query, self.num_heads, self.head_dim)
         key = self._split_heads(key, self.num_heads, self.head_dim)
         value = self._split_heads(value, self.num_heads, self.head_dim)
-        
+
+        query = self.query(query)
+        key = self.key(key)
+        value = self.value(value)
+
         kv_seq_len=key.shape[-2]
         if layer_past is not None:
             kv_seq_len+=layer_past[0].shape[-2]
-        
+
         # Apply rope embedding to query and key
         if self.rot_emb:
             bsz, q_len, _ = hidden_states.size()
@@ -225,7 +254,6 @@ def forward(
             key = torch.cat((past_key, key), dim=-2)
             value = torch.cat((past_value, value), dim=-2)
 
-
         if use_cache is True:
             present = (key, value)
         else:
@@ -251,10 +279,20 @@ class APTMLP(nn.Module):
     def __init__(self, intermediate_size, config):
         super().__init__()
         embed_dim = config.hidden_size
+
+        #muP
+        use_mup = config.use_mup
+
         self.c_fc = Conv1D(intermediate_size, embed_dim)
+
         self.c_proj = Conv1D(embed_dim, intermediate_size)
+
         self.act = ACT2FN[config.activation_function]
-        self.dropout = nn.Dropout(config.resid_pdrop)
+
+        if use_mup:
+            self.dropout = nn.Identity()
+        else:
+            self.dropout = nn.Dropout(config.resid_pdrop)
 
     def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
         hidden_states = self.c_fc(hidden_states)
@@ -270,6 +308,9 @@ def __init__(self, config, layer_idx=None):
         hidden_size = config.hidden_size
         inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
 
+        #muP
+        self.use_mup = config.use_mup
+
         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
         self.attn = APTAttention(config, layer_idx=layer_idx)
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
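The `nn.Identity` modules threaded through the attention path above (`attn_score`, `query`, `key`, `value`), like the dropout layers that become `nn.Identity` when `use_mup` is set, do not change the forward computation; they exist so that mup's coordinate check has a named module whose output it can record through forward hooks. A minimal illustration of the pattern, using a hypothetical module that is not from the repo:

    # Illustrative only: an nn.Identity "hook point" for mup's coordinate check.
    import torch
    from torch import nn

    class TinyBlock(nn.Module):
        def __init__(self, dim: int):
            super().__init__()
            self.proj = nn.Linear(dim, dim)
            self.proj_out = nn.Identity()  # no-op; gives coord checks a named module to record

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # mup.coord_check.get_coord_data hooks every module, so the Identity's
            # output statistics show up under the name "proj_out" in the plots
            return torch.relu(self.proj_out(self.proj(x)))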
@@ -354,23 +395,32 @@ def __init__(self, config):
         super().__init__(config)
 
         self.embed_dim = config.hidden_size
+        use_mup = config.use_mup
 
         self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
+
         self.position_embedding = config.position_embedding if hasattr(config, "position_embedding") else "learned"
-        
+
         if self.position_embedding=="learned" or self.position_embedding == 'rope' or self.position_embedding == 'rerope' or self.position_embedding=="linear_rope_scaling" or self.position_embedding =="dynamic_rope_scaling":
             self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
             self.alibi = None
         elif self.position_embedding=="alibi":
+            #muP TO DO: check proper behavior in alibi case
             maxpos = config.n_positions
             attn_heads = config.n_head
             alibi = create_alibi_tensor(attn_heads,maxpos)
             self.register_buffer('alibi',alibi)
         else:
             raise Exception(f'position_embedding {self.position_embedding} not supported. Please select one of learned, rope, rerope, linear rope, dynamic rope or alibi')
-        
-        self.drop = nn.Dropout(config.embd_pdrop)
+
+        #muP
+        if use_mup:
+            self.drop = nn.Identity()
+        else:
+            self.drop = nn.Dropout(config.embd_pdrop)
+
         self.h = nn.ModuleList([APTBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)])
+
         self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
 
         # Model parallel
@@ -477,7 +527,7 @@ def forward(
             hidden_states = inputs_embeds + position_embeds
         else:
             hidden_states = inputs_embeds
-        
+
         if token_type_ids is not None:
             token_type_embeds = self.wte(token_type_ids)
 
@@ -593,19 +643,78 @@ class APTLMHeadModel(GPT2PreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.transformer = APTModel(config)
-        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+
+        # muP
+        # TO DO: look into weight tying
+        # TO DO: if weight tying is used, APTMuSharedReadout with the proper tied weight should be used instead
+        self.lm_head = MuReadout(config.n_embd,
+                                 config.vocab_size,
+                                 bias=False,
+                                 readout_zero_init=config.readout_zero_init,
+                                 output_mult=config.mup_output_mult)
+
+        # mup
+        # note that this has to be run after mup.set_base_shape for it to work
+        # see https://github.com/microsoft/mup#basic-usage
+        # not sure if this is required here
+        self.apply(self._init_weights)
 
         # Model parallel
         self.model_parallel = False
         self.device_map = None
-        self.contact_head=ContactPredictionHead(config.num_hidden_layers * config.num_attention_heads,
-                                                prepend_bos=True,
-                                                append_eos=True,
-                                                eos_idx=2)
+        # mup implementation does not currently support this
+        if config.contact_prediction_head:
+            self.contact_head=ContactPredictionHead(config.num_hidden_layers * config.num_attention_heads,
+                                                    prepend_bos=True,
+                                                    append_eos=True,
+                                                    eos_idx=2)
 
         # Initialize weights and apply final processing
         self.post_init()
+
+    # mup
+    # general function for mup-specific weight initialization
+    def _init_weights(self, module):
+        if isinstance(module, (MuReadout, MuSharedReadout)) and self.config.readout_zero_init:
+            module.weight.data.zero_()
+        elif isinstance(module, (nn.Linear, Conv1D)):
+            if hasattr(module.weight, 'infshape'):
+                normal_(module.weight, mean=0.0, std=self.config.initializer_range)
+            else:
+                module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+            if module.padding_idx is not None:
+                module.weight.data[module.padding_idx].zero_()
+        elif isinstance(module, nn.LayerNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+
+        if isinstance(module, APTAttention):
+            if hasattr(module, "q_attn"):
+                # cross attention case
+                if self.config.query_zero_init:
+                    # q_attn corresponds to the first third of c_attn in the no cross attention case -- zero initialization
+                    module.q_attn.weight.data.zero_()
+            else:
+                if self.config.query_zero_init:
+                    _, fanout = module.c_attn.weight.shape
+                    assert fanout % 3 == 0
+                    module.c_attn.weight.data[:, :fanout//3] = 0
+
+            depth_std = self.config.initializer_range / math.sqrt(2 * self.config.n_layer)
+            for name, p in module.named_parameters():
+                if "c_proj" in name and "weight" in name:
+                    # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
+                    if hasattr(p, 'infshape'):
+                        normal_(p, mean=0.0, std=depth_std)
+                    else:
+                        p.data.normal_(mean=0.0, std=depth_std)
+
+
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
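As the comments in APTLMHeadModel note, the muP-aware initialization only has its intended effect once `mup.set_base_shapes` has attached `infshape` metadata to the parameters (otherwise `_init_weights` falls back to plain normal initialization). A sketch of the expected setup order, mirroring the test scripts added below (the widths are illustrative):

    # Sketch of the muP setup order used by the tests in this diff (widths illustrative).
    from mup import set_base_shapes
    from protein_lm.modeling import APTConfig, APTLMHeadModel

    base_model = APTLMHeadModel(config=APTConfig(n_embd=1, n_layer=8, num_attention_heads=1, n_inner=1, use_mup=True))
    delta_model = APTLMHeadModel(config=APTConfig(n_embd=200, n_layer=8, num_attention_heads=10, n_inner=200, use_mup=True))
    model = APTLMHeadModel(config=APTConfig(n_embd=512, n_layer=8, num_attention_heads=32, n_inner=512, use_mup=True))

    # 1) attach infshapes: base and delta models tell mup how each dimension scales with width
    set_base_shapes(model, base_model, delta=delta_model)
    # 2) only now re-run the muP-aware init so normal_ / MuReadout see the infshape metadata
    model.apply(model._init_weights)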
diff --git a/protein_lm/modeling/scripts/train.py b/protein_lm/modeling/scripts/train.py
index 7c2d555..e993f2c 100644
--- a/protein_lm/modeling/scripts/train.py
+++ b/protein_lm/modeling/scripts/train.py
@@ -53,6 +53,8 @@ def train(
         config_dict["wandb"],
     )
 
+    # TO DO: add support for mup's optimizers in case use_mup is used, see e.g. https://github.com/microsoft/mup/blob/19814971934ef91dd546f88e913fc963e096d11c/mup/optim.py
+    # available via mup.optim
     trainer = Trainer(
         model=model,
         args=training_args,
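On the TO DO left in train.py above: muP's learning-rate scaling is only realized when a mup-aware optimizer is used. A hedged sketch of what that wiring could look like (the `build_optimizer` helper is hypothetical and not part of this diff; `MuAdamW` is provided by `mup.optim`):

    # Sketch only: mup-aware optimizer selection for the use_mup case (hypothetical helper).
    from mup.optim import MuAdamW
    from torch.optim import AdamW

    def build_optimizer(model, lr: float, use_mup: bool):
        if use_mup:
            # MuAdamW rescales per-parameter-group learning rates from the infshapes,
            # so set_base_shapes must already have been applied to `model`
            return MuAdamW(model.parameters(), lr=lr)
        return AdamW(model.parameters(), lr=lr)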
diff --git a/protein_lm/tests/test_coord_check.py b/protein_lm/tests/test_coord_check.py
new file mode 100644
index 0000000..e463d5a
--- /dev/null
+++ b/protein_lm/tests/test_coord_check.py
@@ -0,0 +1,37 @@
+from mup import make_base_shapes, set_base_shapes, get_shapes, MuAdam, MuSGD, MuAdamW
+from mup.coord_check import get_coord_data, plot_coord_data
+from functools import partial
+import torch
+from protein_lm.modeling import APTConfig, APTLMHeadModel
+
+
+# not sure how to leverage pytest in the context of coordinate checking
+# this is because visual inspection of coordinate checking results is necessary
+# the test will generate coordinate checking results to a test_results directory for now
+
+if __name__ == "__main__":
+    delta_model = APTLMHeadModel(config=APTConfig(n_embd=200, n_layer=8, num_attention_heads=10, n_inner=200, use_mup=True))
+    delta_model.apply(delta_model._init_weights)
+
+    base_model = APTLMHeadModel(config=APTConfig(n_embd=1, n_layer=8, num_attention_heads=1, n_inner=1, use_mup=True))
+    base_model.apply(base_model._init_weights)
+
+    def get_mup_apt_model(width):
+        model = APTLMHeadModel(config=APTConfig(n_embd=width, n_layer=8, num_attention_heads=width//16, n_inner=width, use_mup=True))
+        return model
+
+    def set_up_mup_apt_model(width):
+        model = set_base_shapes(get_mup_apt_model(width), base_model, delta=delta_model)
+        model.apply(model._init_weights)
+        return model
+
+    def get_mup_lazy_model(width):
+        return lambda: set_up_mup_apt_model(width)
+
+    models = {256: get_mup_lazy_model(256), 512: get_mup_lazy_model(512), 1024: get_mup_lazy_model(1024), 2048: get_mup_lazy_model(2048)}
+
+    input_ids = torch.randint(low=0, high=50257, size=(1, 256)).to(torch.int64)
+    labels = torch.randint(low=0, high=50257, size=(1, 256)).to(torch.int64)
+    dataloader=[{'input_ids': input_ids, 'labels': labels}]
+    df = get_coord_data(models, dataloader, optimizer='sgd', lr=0.1, dict_in_out=True, output_name='loss', cuda=True, nsteps=10, nseeds=10)
+    plot_coord_data(df, legend=None, save_to='test_results/apt_coordcheck.jpg')
\ No newline at end of file
diff --git a/protein_lm/tests/test_optuna_mup.py b/protein_lm/tests/test_optuna_mup.py
new file mode 100644
index 0000000..8634a44
--- /dev/null
+++ b/protein_lm/tests/test_optuna_mup.py
@@ -0,0 +1,59 @@
+import optuna
+from mup import make_base_shapes, set_base_shapes, get_shapes, MuAdam, MuSGD, MuAdamW
+from mup.coord_check import get_coord_data, plot_coord_data
+from functools import partial
+import torch
+from protein_lm.modeling import APTConfig, APTLMHeadModel
+
+def get_mup_apt_model(width):
+    model = APTLMHeadModel(config=APTConfig(n_embd=width, n_layer=8, num_attention_heads=width//16, n_inner=width, use_mup=True))
+    return model
+
+def set_up_mup_apt_model(width):
+    model = set_base_shapes(get_mup_apt_model(width), base_model, delta=delta_model)
+    model.apply(model._init_weights)
+    return model
+
+def get_mup_lazy_model(width):
+    return lambda: set_up_mup_apt_model(width)
+
+def objective(trial):
+    # Suggest a model width (n_embd) as a hyperparameter
+    width = trial.suggest_categorical('width', [256, 512, 1024, 2048])
+
+    # Configure and instantiate models based on the trial's suggested width
+    model_config = get_mup_lazy_model(width)()
+
+    # Here, you can insert your training and evaluation logic
+    # For demonstration, we're using a simplified version of your existing code to get coordinate data
+    input_ids = torch.randint(low=0, high=50257, size=(1, 256)).to(torch.int64)
+    labels = torch.randint(low=0, high=50257, size=(1, 256)).to(torch.int64)
+    dataloader = [{'input_ids': input_ids, 'labels': labels}]
+
+    # Assuming 'get_coord_data' returns a DataFrame with a 'loss' column
+    df = get_coord_data({width: lambda: model_config}, dataloader, optimizer='sgd', lr=0.1, dict_in_out=True, output_name='loss', cuda=True, nsteps=10, nseeds=1)
+
+    # Here, we assume the DataFrame contains the loss values and we return the average loss
+    # Adjust this based on how your actual loss values are calculated and returned
+    avg_loss = df['loss'].mean()
+
+    return avg_loss
+
+
+
+if __name__ == "__main__":
+    delta_model = APTLMHeadModel(config=APTConfig(n_embd=200, n_layer=8, num_attention_heads=10, n_inner=200, use_mup=True))
+    delta_model.apply(delta_model._init_weights)
+
+    base_model = APTLMHeadModel(config=APTConfig(n_embd=1, n_layer=8, num_attention_heads=1, n_inner=1, use_mup=True))
+    base_model.apply(base_model._init_weights)
+
+
+
+    models = {256: get_mup_lazy_model(256), 512: get_mup_lazy_model(512), 1024: get_mup_lazy_model(1024), 2048: get_mup_lazy_model(2048)}
+
+    input_ids = torch.randint(low=0, high=50257, size=(1, 256)).to(torch.int64)
+    labels = torch.randint(low=0, high=50257, size=(1, 256)).to(torch.int64)
+    dataloader=[{'input_ids': input_ids, 'labels': labels}]
+    df = get_coord_data(models, dataloader, optimizer='sgd', lr=0.1, dict_in_out=True, output_name='loss', cuda=True, nsteps=10, nseeds=10)
+    plot_coord_data(df, legend=None, save_to='test_results/apt_coordcheck.jpg')
\ No newline at end of file
diff --git a/protein_lm_cuda.yml b/protein_lm_cuda.yml
index 3ce6cb4..c6511c8 100644
--- a/protein_lm_cuda.yml
+++ b/protein_lm_cuda.yml
@@ -20,3 +20,4 @@ dependencies:
   - evaluate
   - pytest
   - fair-esm
+  - mup
\ No newline at end of file
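A usage note on the two test scripts above: both are plain `__main__` scripts rather than pytest tests, so they are run directly (e.g. `python protein_lm/tests/test_coord_check.py`); the `test_results/` directory they write plots to may need to exist beforehand, and `cuda=True` assumes a GPU is available. In the resulting coordinate-check plots, per-module activation scales should stay roughly flat as width grows if the muP setup is correct; curves that grow or shrink systematically with width point to a parametrization bug.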