Optuna super simple implementation #67

Open · wants to merge 9 commits into base: main
Changes from all commits
1 change: 1 addition & 0 deletions build/lib/protein_lm/modeling/__init__.py
@@ -0,0 +1 @@
from .models import *
1 change: 1 addition & 0 deletions protein_lm.yml
@@ -20,3 +20,4 @@ dependencies:
- evaluate
- pytest
- fair-esm
- mup
1 change: 1 addition & 0 deletions protein_lm/modeling/__init__.py
@@ -0,0 +1 @@
from .models import *
1 change: 1 addition & 0 deletions protein_lm/modeling/models/__init__.py
@@ -0,0 +1 @@
from .apt import *
2 changes: 2 additions & 0 deletions protein_lm/modeling/models/apt/__init__.py
@@ -0,0 +1,2 @@
from .config import APTConfig
from .model_pytorch import *
25 changes: 25 additions & 0 deletions protein_lm/modeling/models/apt/config.py
@@ -11,11 +11,36 @@ def __init__(
position_embedding="learned",
tokenizer=None,
max_sequence_length = 1024,
query_zero_init = True,
n_layer = None,
contact_prediction_head = False,
initializer_range = 0.02,
# whether to use MuParametrization
use_mup = False,
# whether to initialize the output (readout) layer with zero-initialization
readout_zero_init = True,
# the output layer multiplier if mup is used, see https://github.com/microsoft/mup/blob/19814971934ef91dd546f88e913fc963e096d11c/mup/layer.py#L56
mup_output_mult = 1.0,
width_mult_for_weights = 2.0,
# rope
rope_theta = 0.0,
rope_scaling_factor=1,
**kwargs
):
super().__init__(**kwargs)
self.nn_model_type = "APT"
self.position_embedding = position_embedding
self.tokenizer = tokenizer
self.max_sequence_length = max_sequence_length

self.use_mup = use_mup
        self.query_zero_init = query_zero_init
        if n_layer is not None:
            self.n_layer = n_layer
self.contact_prediction_head = contact_prediction_head
self.initializer_range = initializer_range
self.readout_zero_init = readout_zero_init
self.mup_output_mult = mup_output_mult
self.width_mult_for_weights = width_mult_for_weights
self.rope_theta = rope_theta
self.rope_scaling_factor = rope_scaling_factor

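For context, here is a minimal sketch of how the new muP-related options might be passed when building a config. The values are illustrative only, and the import path assumes the package layout introduced in this PR:

```python
from protein_lm.modeling.models.apt.config import APTConfig

# Illustrative only: enable muP with zero-initialized readout and query projections.
config = APTConfig(
    position_embedding="learned",
    max_sequence_length=1024,
    use_mup=True,
    readout_zero_init=True,
    query_zero_init=True,
    mup_output_mult=1.0,
    width_mult_for_weights=2.0,
)
```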
153 changes: 131 additions & 22 deletions protein_lm/modeling/models/apt/model_pytorch.py
@@ -1,4 +1,5 @@
from typing import Optional, Tuple, Union
import math
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
@@ -8,6 +9,7 @@
from transformers.pytorch_utils import Conv1D
from transformers.activations import ACT2FN
from transformers.utils import logging
from mup import MuReadout, MuSharedReadout, normal_
from protein_lm.modeling.utils.rotary_embedding import RotaryEmbedding
from protein_lm.modeling.utils.rerope_embedding import RectifiedRotaryEmbedding
from protein_lm.modeling.utils.alibi_embedding import create_alibi_tensor
@@ -16,6 +18,7 @@

logger = logging.get_logger(__name__)


class APTAttention(GPT2Attention):
def __init__(self, config, is_cross_attention=False, layer_idx=None):
super().__init__(config, is_cross_attention=is_cross_attention, layer_idx=layer_idx)
@@ -42,6 +45,13 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
f" {self.num_heads})."
)

# muP
self.use_mup = config.use_mup
self.attn_score = nn.Identity() # just for coordcheck
self.query = nn.Identity() # just for coordcheck
self.key = nn.Identity() # just for coordcheck
self.value = nn.Identity() # just for coordcheck

self.scale_attn_weights = config.scale_attn_weights
self.is_cross_attention = is_cross_attention

@@ -55,13 +65,20 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
self.q_attn = Conv1D(self.embed_dim, self.embed_dim)
else:
self.c_attn = Conv1D(3 * self.embed_dim, self.embed_dim)

self.c_proj = Conv1D(self.embed_dim, self.embed_dim)

        # muP
        if self.use_mup:
            self.attn_dropout = nn.Identity()
            self.resid_dropout = nn.Identity()
        else:
            self.attn_dropout = nn.Dropout(config.attn_pdrop)
            self.resid_dropout = nn.Dropout(config.resid_pdrop)

self.pruned_heads = set()



self.rot_emb=None
if self.position_embedding == "rope":
self.rot_emb=RotaryEmbedding(dim=self.head_dim)
@@ -72,15 +89,23 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
elif self.position_embedding=="dynamic_rope_scaling":
self.rot_emb=LlamaDynamicNTKScalingRotaryEmbedding(dim=self.head_dim,max_position_embeddings=self.max_positions,scaling_factor=self.rope_scaling_factor,base=self.rope_theta)



def _attn(self, query, key, value, attention_mask=None, head_mask=None,alibi_bias=None):
attn_weights = torch.matmul(query, key.transpose(-1, -2))

#muP
        if self.scale_attn_weights:
            if self.use_mup:
                # muP prescribes 1/d attention scaling instead of 1/sqrt(d)
                attn_weights = attn_weights / torch.full(
                    [], value.size(-1), dtype=attn_weights.dtype, device=attn_weights.device
                )
            else:
                attn_weights = attn_weights / torch.full(
                    [], value.size(-1) ** 0.5, dtype=attn_weights.dtype, device=attn_weights.device
                )

attn_weights = self.attn_score(attn_weights)

# Layer-wise attention scaling
if self.scale_attn_by_inverse_layer_idx:
@@ -97,7 +122,7 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None,alibi_bia
attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
if alibi_bias is not None:
attn_weights = attn_weights + alibi_bias[:,:,:attn_weights.size(-1)]

if attention_mask is not None:
# Apply the attention mask
attn_weights = attn_weights + attention_mask
@@ -150,7 +175,7 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, hea

if alibi_bias is not None:
attn_weights = attn_weights + alibi_bias[:,:,:attn_weights.size(-1)]

if attention_mask is not None:
# Apply the attention mask
attn_weights = attn_weights + attention_mask
@@ -171,7 +196,7 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, hea

return attn_output, attn_weights


def forward(
self,
hidden_states: Optional[Tuple[torch.FloatTensor]],
@@ -202,11 +227,15 @@ def forward(
query = self._split_heads(query, self.num_heads, self.head_dim)
key = self._split_heads(key, self.num_heads, self.head_dim)
value = self._split_heads(value, self.num_heads, self.head_dim)


query = self.query(query)
key = self.key(key)
value = self.value(value)

kv_seq_len=key.shape[-2]
if layer_past is not None:
kv_seq_len+=layer_past[0].shape[-2]

# Apply rope embedding to query and key
if self.rot_emb:
bsz, q_len, _ = hidden_states.size()
@@ -225,7 +254,6 @@ def forward(
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)


if use_cache is True:
present = (key, value)
else:
@@ -251,10 +279,20 @@ class APTMLP(nn.Module):
def __init__(self, intermediate_size, config):
super().__init__()
embed_dim = config.hidden_size

#muP
use_mup = config.use_mup

self.c_fc = Conv1D(intermediate_size, embed_dim)

self.c_proj = Conv1D(embed_dim, intermediate_size)

self.act = ACT2FN[config.activation_function]
        if use_mup:
            self.dropout = nn.Identity()
        else:
            self.dropout = nn.Dropout(config.resid_pdrop)

def forward(self, hidden_states: Optional[Tuple[torch.FloatTensor]]) -> torch.FloatTensor:
hidden_states = self.c_fc(hidden_states)
@@ -270,6 +308,9 @@ def __init__(self, config, layer_idx=None):
hidden_size = config.hidden_size
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size

#muP
self.use_mup = config.use_mup

self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = APTAttention(config, layer_idx=layer_idx)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
@@ -354,23 +395,32 @@ def __init__(self, config):
super().__init__(config)

self.embed_dim = config.hidden_size
use_mup = config.use_mup

self.wte = nn.Embedding(config.vocab_size, self.embed_dim)

self.position_embedding = config.position_embedding if hasattr(config, "position_embedding") else "learned"

if self.position_embedding=="learned" or self.position_embedding == 'rope' or self.position_embedding == 'rerope' or self.position_embedding=="linear_rope_scaling" or self.position_embedding =="dynamic_rope_scaling":
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.alibi = None
elif self.position_embedding=="alibi":
#muP TO DO: check proper behavior in alibi case
maxpos = config.n_positions
attn_heads = config.n_head
alibi = create_alibi_tensor(attn_heads,maxpos)
self.register_buffer('alibi',alibi)
else:
raise Exception(f'position_embedding {self.position_embedding} not supported. Please select one of learned, rope, rerope, linear rope, dynamic rope or alibi')

        # muP
        if use_mup:
            self.drop = nn.Identity()
        else:
            self.drop = nn.Dropout(config.embd_pdrop)

self.h = nn.ModuleList([APTBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)])

self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

# Model parallel
@@ -477,7 +527,7 @@ def forward(
hidden_states = inputs_embeds + position_embeds
else:
hidden_states = inputs_embeds


if token_type_ids is not None:
token_type_embeds = self.wte(token_type_ids)
@@ -593,19 +643,78 @@ class APTLMHeadModel(GPT2PreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.transformer = APTModel(config)

# muP
# TO DO: look into weight tying
# TO DO: if weight tying is used, APTMuSharedReadout with the proper tied weight should be used instead
        self.lm_head = MuReadout(config.n_embd,
                                 config.vocab_size,
                                 bias=False,
                                 readout_zero_init=config.readout_zero_init,
                                 output_mult=config.mup_output_mult)

# mup
        # note that this has to be run after mup.set_base_shapes for it to work
# see https://github.com/microsoft/mup#basic-usage
# not sure if this is required here
self.apply(self._init_weights)

# Model parallel
self.model_parallel = False
self.device_map = None

        # mup implementation does not currently support this
        if config.contact_prediction_head:
            self.contact_head = ContactPredictionHead(config.num_hidden_layers * config.num_attention_heads,
                                                      prepend_bos=True,
                                                      append_eos=True,
                                                      eos_idx=2)

# Initialize weights and apply final processing
self.post_init()

# mup
# general function for mup-specific weight initialization
def _init_weights(self, module):
if isinstance(module, (MuReadout, MuSharedReadout)) and self.config.readout_zero_init:
module.weight.data.zero_()
elif isinstance(module, (nn.Linear, Conv1D)):
if hasattr(module.weight, 'infshape'):
normal_(module.weight, mean=0.0, std=self.config.initializer_range)
else:
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
elif isinstance(module, nn.LayerNorm):
module.bias.data.zero_()
module.weight.data.fill_(1.0)

if isinstance(module, APTAttention):
if hasattr(module, "q_attn"):
# cross attention case
if self.config.query_zero_init:
                    # q_attn plays the role of the query (first) third of c_attn in the non-cross-attention case -- zero-initialize it
                    module.q_attn.weight.data.zero_()
else:
if self.config.query_zero_init:
_, fanout = module.c_attn.weight.shape
assert fanout % 3 == 0
module.c_attn.weight.data[:, :fanout//3] = 0

depth_std = self.config.initializer_range / math.sqrt(2 * self.config.n_layer)
for name, p in module.named_parameters():
if "c_proj" in name and "weight" in name:
# Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
if hasattr(p, 'infshape'):
normal_(p, mean=0.0, std=depth_std)
else:
p.data.normal_(mean=0.0, std=depth_std)


def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
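As a reviewer aid, here is a minimal sketch of the muP workflow the comments above refer to: `mup.set_base_shapes` must run before the mup-aware `_init_weights` has any effect, since the latter checks for the `infshape` attribute that `set_base_shapes` attaches. The helper below and the chosen widths are assumptions for illustration, not part of this PR.

```python
import mup

from protein_lm.modeling.models.apt.config import APTConfig
from protein_lm.modeling.models.apt.model_pytorch import APTLMHeadModel


def build_mup_model(width: int, base_width: int = 128, delta_width: int = 256):
    """Hypothetical helper: build an APT model with muP base shapes set."""

    def make(n_embd):
        # use_mup and the other field names are taken from the config diff above
        cfg = APTConfig(n_embd=n_embd, use_mup=True)
        return APTLMHeadModel(cfg)

    model = make(width)
    base_model = make(base_width)
    delta_model = make(delta_width)

    # Attach infshape metadata to every parameter, then re-run the custom init so
    # the mup-aware branches (normal_, MuReadout zero-init) actually take effect.
    mup.set_base_shapes(model, base_model, delta=delta_model)
    model.apply(model._init_weights)
    return model
```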
2 changes: 2 additions & 0 deletions protein_lm/modeling/scripts/train.py
@@ -53,6 +53,8 @@ def train(
config_dict["wandb"],
)

# TO DO: add support for mup's optimizers in case use_mup is used, see e.g. https://github.com/microsoft/mup/blob/19814971934ef91dd546f88e913fc963e096d11c/mup/optim.py
# available via mup.optim
trainer = Trainer(
model=model,
args=training_args,
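The TO DO above points at mup's drop-in optimizers in `mup.optim`. Below is a minimal sketch of how one might be wired in; the helper name and the `use_mup`/`lr` plumbing are assumptions, not part of this PR, and the resulting optimizer could then be handed to the Trainer via its `optimizers=(optimizer, lr_scheduler)` argument.

```python
import torch
from mup.optim import MuAdam  # MuSGD is also available in mup.optim


def build_optimizer(model: torch.nn.Module, use_mup: bool, lr: float) -> torch.optim.Optimizer:
    """Hypothetical helper: pick a muP-aware optimizer when use_mup is set."""
    if use_mup:
        # MuAdam rescales per-parameter-group learning rates according to the muP rules,
        # using the infshape metadata attached by mup.set_base_shapes.
        return MuAdam(model.parameters(), lr=lr)
    return torch.optim.AdamW(model.parameters(), lr=lr)
```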