From 587b05ea8c36cb79290e3d832d036ef03b108d27 Mon Sep 17 00:00:00 2001
From: atenzer
Date: Wed, 9 Feb 2022 11:48:26 +0800
Subject: [PATCH] [#42] Modeling refactoring and skeleton for usage.

---
 .../coupled_hierarchical_transformer/usage.py |  43 ++
 .../__init__.py                               |   2 +
 .../config.py                                 |  20 ++
 .../modeling.py                               |   4 +-
 .../new_modeling.py                           | 174 ++++++++++++++++++
 .../coupled_hierarchical_transformer/train.py |  37 +++-
 .../train_config_local.json                   |   2 +-
 7 files changed, 271 insertions(+), 11 deletions(-)
 create mode 100644 demo_api/coupled_hierarchical_transformer/usage.py
 create mode 100644 sgnlp/models/coupled_hierarchical_transformer/new_modeling.py

diff --git a/demo_api/coupled_hierarchical_transformer/usage.py b/demo_api/coupled_hierarchical_transformer/usage.py
new file mode 100644
index 0000000..e1ff78d
--- /dev/null
+++ b/demo_api/coupled_hierarchical_transformer/usage.py
@@ -0,0 +1,43 @@
+import torch
+
+from sgnlp.models.coupled_hierarchical_transformer import (
+    DualBert,
+    prepare_data_for_training
+)
+from sgnlp.models.coupled_hierarchical_transformer.config import DualBertConfig
+
+# model_state_dict = torch.load("/Users/nus/Documents/Code/projects/SGnlp/sgnlp/output/pytorch_model.bin")
+# model = DualBert.from_pretrained(
+#     "bert-base-uncased",
+#     state_dict=model_state_dict,
+#     rumor_num_labels=3,
+#     stance_num_labels=5,
+#     max_tweet_num=17,
+#     max_tweet_length=30,
+#     convert_size=20,
+# )
+#
+# print("x")
+
+# NOTE: DualBertPreprocessor is the planned inference preprocessor for this model;
+# it is not part of the package yet, so this skeleton will not run until it is added.
+preprocessor = DualBertPreprocessor()
+
+config = DualBertConfig.from_pretrained("path to config")
+# Load from the directory written by model.save_pretrained(...) during training.
+model = DualBert.from_pretrained("/Users/nus/Documents/Code/projects/SGnlp/sgnlp/output", config=config)
+
+model.eval()
+
+example = [
+    "Claim",
+    "Response 1",
+    "Response 2"
+]
+
+model_inputs = preprocessor([example])
+# { model_param_1: ..., model_param2: ..., ...}
+
+with torch.no_grad():
+    outputs = model(**model_inputs)
+
diff --git a/sgnlp/models/coupled_hierarchical_transformer/__init__.py b/sgnlp/models/coupled_hierarchical_transformer/__init__.py
index e69de29..130d92d 100644
--- a/sgnlp/models/coupled_hierarchical_transformer/__init__.py
+++ b/sgnlp/models/coupled_hierarchical_transformer/__init__.py
@@ -0,0 +1,2 @@
+from .modeling import DualBert
+from .preprocess import prepare_data_for_training
\ No newline at end of file
diff --git a/sgnlp/models/coupled_hierarchical_transformer/config.py b/sgnlp/models/coupled_hierarchical_transformer/config.py
index e69de29..d08b3f7 100644
--- a/sgnlp/models/coupled_hierarchical_transformer/config.py
+++ b/sgnlp/models/coupled_hierarchical_transformer/config.py
@@ -0,0 +1,20 @@
+from transformers import BertConfig
+
+
+class DualBertConfig(BertConfig):
+
+    def __init__(self, rumor_num_labels=2, stance_num_labels=2, max_tweet_num=17, max_tweet_length=30, convert_size=20,
+                 vocab_size=30522, hidden_size=768, num_hidden_layers=12, num_attention_heads=12,
+                 intermediate_size=3072, hidden_act="gelu", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1,
+                 max_position_embeddings=512, type_vocab_size=2, initializer_range=0.02, layer_norm_eps=1e-12,
+                 pad_token_id=0, position_embedding_type="absolute", use_cache=True, classifier_dropout=None, **kwargs):
+        super().__init__(vocab_size, hidden_size, num_hidden_layers, num_attention_heads, intermediate_size, hidden_act,
+                         hidden_dropout_prob, attention_probs_dropout_prob, max_position_embeddings, type_vocab_size,
+                         initializer_range, layer_norm_eps, pad_token_id, position_embedding_type, use_cache,
+                         classifier_dropout, **kwargs)
+
+        self.rumor_num_labels = rumor_num_labels
+        self.stance_num_labels = stance_num_labels
+
self.max_tweet_num = max_tweet_num + self.max_tweet_length = max_tweet_length + self.convert_size = convert_size diff --git a/sgnlp/models/coupled_hierarchical_transformer/modeling.py b/sgnlp/models/coupled_hierarchical_transformer/modeling.py index 6092372..8ce9d75 100644 --- a/sgnlp/models/coupled_hierarchical_transformer/modeling.py +++ b/sgnlp/models/coupled_hierarchical_transformer/modeling.py @@ -27,6 +27,8 @@ import torch from torch import nn from torch.nn import CrossEntropyLoss +from transformers import PreTrainedModel + from .utils import cached_path @@ -752,7 +754,7 @@ def forward(self, sequence_output, pooled_output): return prediction_scores, seq_relationship_score -class PreTrainedBertModel(nn.Module): +class PreTrainedBertModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for dowloading and loading pretrained models. """ diff --git a/sgnlp/models/coupled_hierarchical_transformer/new_modeling.py b/sgnlp/models/coupled_hierarchical_transformer/new_modeling.py new file mode 100644 index 0000000..88d4cca --- /dev/null +++ b/sgnlp/models/coupled_hierarchical_transformer/new_modeling.py @@ -0,0 +1,174 @@ +import torch +from torch import nn +from torch.nn import CrossEntropyLoss +from transformers import BertModel, PreTrainedModel, BertPreTrainedModel +from transformers.models.bert.modeling_bert import BertPooler + +from sgnlp.models.coupled_hierarchical_transformer.config import DualBertConfig +from sgnlp.models.coupled_hierarchical_transformer.modeling import BertCrossEncoder, ADDBertReturnEncoder, \ + MTBertStancePooler, BertPooler_v2, BertSelfLabelAttention + + +class DualBertPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
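+    ``config_class`` is set to ``DualBertConfig``; concrete models such as ``DualBert`` call
+    ``init_weights()`` after constructing their submodules so that ``_init_weights`` runs on them.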
+ """ + + config_class = DualBertConfig + base_model_prefix = "dual_bert" + + def _init_weights(self, module): + """Initialize the weights""" + if isinstance(module, nn.Linear): + # Slightly different from the TF version which uses truncated_normal for initialization + # cf https://github.com/pytorch/pytorch/pull/5617 + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +class DualBert(DualBertPreTrainedModel): + def __init__(self, config, + # rumor_num_labels=2, stance_num_labels=2, max_tweet_num=17, max_tweet_length=30, + # convert_size=20 + ): + super(DualBert, self).__init__(config) + self.rumor_num_labels = config.rumor_num_labels + self.stance_num_labels = config.stance_num_labels + self.bert = BertModel(config) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + self.add_rumor_bert_attention = BertCrossEncoder(config) + self.add_stance_bert_attention = ADDBertReturnEncoder(config) + self.max_tweet_num = config.max_tweet_num + self.max_tweet_length = config.max_tweet_length + self.stance_pooler = MTBertStancePooler(config) + # previous version + # self.rumor_pooler = BertPooler(config) + # self.add_self_attention = BertSelfLabelAttention(config, stance_num_labels) + # self.rumor_classifier = nn.Linear(config.hidden_size+stance_num_labels, rumor_num_labels) + # new version + # self.rumor_pooler = BertPooler_v2(config.hidden_size+stance_num_labels) # +stance_num_labels + # self.add_self_attention = BertSelfLabelAttention(config, config.hidden_size+stance_num_labels) + # self.rumor_classifier = nn.Linear(config.hidden_size+stance_num_labels, rumor_num_labels) + # Version 3 + # self.rumor_pooler = BertPooler(config) + # self.add_self_attention = BertSelfLabelAttention(config, config.hidden_size+stance_num_labels) + # self.rumor_classifier = nn.Linear(config.hidden_size*2+stance_num_labels, rumor_num_labels) + # Version 4 + self.convert_size = config.convert_size # 100 pheme seed 42, 100->0.423, 0.509, 75 OK, 32, 50, 64, 80, 90, 120, 128, 200 not good, + self.rumor_pooler = BertPooler(config) + self.hybrid_rumor_pooler = BertPooler_v2(config.hidden_size + config.stance_num_labels) + self.add_self_attention = BertSelfLabelAttention(config, config.hidden_size + config.stance_num_labels) + self.linear_conversion = nn.Linear(config.hidden_size + config.stance_num_labels, self.convert_size) + self.rumor_classifier = nn.Linear(config.hidden_size + self.convert_size, config.rumor_num_labels) + #### self.rumor_classifier = nn.Linear(config.hidden_size, rumor_num_labels) + self.stance_classifier = nn.Linear(config.hidden_size, config.stance_num_labels) + #### self.cos_sim = nn.CosineSimilarity(dim=-1, eps=1e-6) + # self.apply(self.init_bert_weights) + self.init_weights() + + def forward(self, input_ids1, token_type_ids1, attention_mask1, input_ids2, token_type_ids2, attention_mask2, + input_ids3, token_type_ids3, attention_mask3, input_ids4, token_type_ids4, attention_mask4, + attention_mask, rumor_labels=None, task=None, stance_labels=None, stance_label_mask=None): + + output1 = self.bert(input_ids1, token_type_ids1, attention_mask1, output_hidden_states=False) + output2 = self.bert(input_ids2, token_type_ids2, 
attention_mask2, output_hidden_states=False) + output3 = self.bert(input_ids3, token_type_ids3, attention_mask3, output_hidden_states=False) + output4 = self.bert(input_ids4, token_type_ids4, attention_mask4, output_hidden_states=False) + + sequence_output1 = output1.last_hidden_state + sequence_output2 = output2.last_hidden_state + sequence_output3 = output3.last_hidden_state + sequence_output4 = output4.last_hidden_state + + tmp_sequence = torch.cat((sequence_output1, sequence_output2), dim=1) + tmp_sequence = torch.cat((tmp_sequence, sequence_output3), dim=1) + sequence_output = torch.cat((tmp_sequence, sequence_output4), dim=1) + + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + extended_attention_mask = extended_attention_mask.to( + dtype=next(self.parameters()).dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + + # for stance classification task + # ''' + # ##add_output_layer = self.add_self_attention(sequence_output, extended_attention_mask) + add_stance_bert_encoder, stance_attention_probs = self.add_stance_bert_attention(sequence_output, + extended_attention_mask) + final_stance_text_output = add_stance_bert_encoder[-1] + stance_attention = stance_attention_probs[-1] + label_logit_output = self.stance_pooler(final_stance_text_output, self.max_tweet_num, self.max_tweet_length) + sequence_stance_output = self.dropout(label_logit_output) + stance_logits = self.stance_classifier(sequence_stance_output) + # ''' + + if task is None: # for rumor detection task + # ''' + add_rumor_bert_encoder, rumor_attention_probs = self.add_rumor_bert_attention(final_stance_text_output, + sequence_output, + extended_attention_mask) + add_rumor_bert_text_output_layer = add_rumor_bert_encoder[-1] + rumor_attention = rumor_attention_probs[-1] + + # ''' add label attention layer to incorporate stance predictions for rumor verification + extended_label_mask = stance_label_mask.unsqueeze(1).unsqueeze(2) + extended_label_mask = extended_label_mask.to( + dtype=next(self.parameters()).dtype) # fp16 compatibility + extended_label_mask = (1.0 - extended_label_mask) * -10000.0 + + rumor_output = self.rumor_pooler(add_rumor_bert_text_output_layer) + tweet_level_output = self.stance_pooler(add_rumor_bert_text_output_layer, self.max_tweet_num, + self.max_tweet_length) + final_rumor_output = torch.cat((tweet_level_output, stance_logits), dim=-1) # stance_logits + combined_layer, attention_probs = self.add_self_attention(final_rumor_output, extended_label_mask) + hybrid_rumor_stance_output = self.hybrid_rumor_pooler(combined_layer) + hybrid_conversion_output = self.linear_conversion(hybrid_rumor_stance_output) + final_rumor_text_output = torch.cat((rumor_output, hybrid_conversion_output), dim=-1) + rumor_pooled_output = self.dropout(final_rumor_text_output) + logits = self.rumor_classifier(rumor_pooled_output) + # ''' + + if rumor_labels is not None: + # alpha = 0.1 + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.rumor_num_labels), rumor_labels.view(-1)) + # sim_loss = self.cos_sim(stance_attention, rumor_attention) + # return loss + alpha*sim_loss + return loss + else: + # return logits + return logits, attention_probs[:, 0, 0, :] + # fisrt 0 denotes head, second 0 denotes the first position's attention over all the tweets + else: + # for stance classification task + + # label_logit_output = self.stance_pooler(sequence_output) + ''' + label_logit_output = self.stance_pooler(final_stance_text_output) + sequence_stance_output = 
self.dropout(label_logit_output) + stance_logits = self.stance_classifier(sequence_stance_output) + ''' + + if stance_labels is not None: # for stance classification task + loss_fct = CrossEntropyLoss() + # Only keep active parts of the loss + if stance_label_mask is not None: + active_loss = stance_label_mask.view(-1) == 1 + # print(active_loss) + # print(logits) + active_logits = stance_logits.view(-1, self.stance_num_labels)[active_loss] + active_labels = stance_labels.view(-1)[active_loss] + loss = loss_fct(active_logits, active_labels) + else: + loss = loss_fct(stance_logits.view(-1, self.stance_num_labels), stance_labels.view(-1)) + return loss + else: + return stance_logits \ No newline at end of file diff --git a/sgnlp/models/coupled_hierarchical_transformer/train.py b/sgnlp/models/coupled_hierarchical_transformer/train.py index d515d6b..efdeb40 100644 --- a/sgnlp/models/coupled_hierarchical_transformer/train.py +++ b/sgnlp/models/coupled_hierarchical_transformer/train.py @@ -27,10 +27,12 @@ from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler from torch.utils.data.distributed import DistributedSampler import torch.nn.functional as F + +from sgnlp.models.coupled_hierarchical_transformer.config import DualBertConfig from .utils import classification_report from transformers import BertTokenizer -from .modeling import DualBert +from .new_modeling import DualBert from transformers import AdamW from .utils import PYTORCH_PRETRAINED_BERT_CACHE from .preprocess import prepare_data_for_training, convert_examples_to_features @@ -316,16 +318,25 @@ def train_custom_cht(train_config: CustomCoupledHierarchicalTransformerTrainConf # Prepare model print("The current multi-task learning model is our Dual Bert model...") - model = DualBert.from_pretrained( - train_config.bert_model, - cache_dir=PYTORCH_PRETRAINED_BERT_CACHE - / "distributed_{}".format(train_config.local_rank), + config = DualBertConfig( rumor_num_labels=train_config.rumor_num_labels, stance_num_labels=train_config.stance_num_labels, max_tweet_num=train_config.max_tweet_num, max_tweet_length=train_config.max_tweet_length, convert_size=train_config.convert_size, ) + model = DualBert(config) + + # model = DualBert.from_pretrained( + # train_config.bert_model, + # # cache_dir=PYTORCH_PRETRAINED_BERT_CACHE + # # / "distributed_{}".format(train_config.local_rank), + # # rumor_num_labels=train_config.rumor_num_labels, + # # stance_num_labels=train_config.stance_num_labels, + # # max_tweet_num=train_config.max_tweet_num, + # # max_tweet_length=train_config.max_tweet_length, + # # convert_size=train_config.convert_size, + # ) if train_config.fp16: model.half() @@ -735,7 +746,8 @@ def train_custom_cht(train_config: CustomCoupledHierarchicalTransformerTrainConf model.module if hasattr(model, "module") else model ) # Only save the model it-self if train_config.do_train: - torch.save(model_to_save.state_dict(), output_model_file) + # torch.save(model_to_save.state_dict(), output_model_file) + model.save_pretrained(train_config.output_dir) max_acc_f1 = eval_accuracy + F_score # max_acc_f1 = eval_accuracy+F_score+stance_F_score @@ -939,6 +951,13 @@ def train_custom_cht(train_config: CustomCoupledHierarchicalTransformerTrainConf model_state_dict = torch.load(output_model_file) + # config = DualBertConfig(rumor_num_labels=train_config.rumor_num_labels, + # stance_num_labels=train_config.stance_num_labels, + # max_tweet_num=train_config.max_tweet_num, + # max_tweet_length=train_config.max_tweet_length, + # 
convert_size=train_config.convert_size,) + # model = DualBert(config) + model = DualBert.from_pretrained( train_config.bert_model, state_dict=model_state_dict, @@ -1015,7 +1034,7 @@ def train_custom_cht(train_config: CustomCoupledHierarchicalTransformerTrainConf [f.input_mask for f in eval_features], dtype=torch.int32 ) all_label_ids = torch.tensor( - [f.label_id for f in eval_features], dtype=torch.int32 + [f.label_id for f in eval_features], dtype=torch.long ) all_label_mask = torch.tensor( [f.label_mask for f in eval_features], dtype=torch.int32 @@ -1393,6 +1412,6 @@ def train_custom_cht(train_config: CustomCoupledHierarchicalTransformerTrainConf if __name__ == "__main__": - #train_config = load_train_config("/Users/nus/Documents/Code/projects/SGnlp/sgnlp/sgnlp/models/coupled_hierarchical_transformer/train_config_local.json") - train_config = load_train_config("/polyaxon-data/workspace/atenzer/CHT_demo/train_config.json") + train_config = load_train_config("/Users/nus/Documents/Code/projects/SGnlp/sgnlp/sgnlp/models/coupled_hierarchical_transformer/train_config_local.json") + # train_config = load_train_config("/polyaxon-data/workspace/atenzer/CHT_demo/train_config.json") train_custom_cht(train_config) diff --git a/sgnlp/models/coupled_hierarchical_transformer/train_config_local.json b/sgnlp/models/coupled_hierarchical_transformer/train_config_local.json index ff03c0b..c7a850b 100644 --- a/sgnlp/models/coupled_hierarchical_transformer/train_config_local.json +++ b/sgnlp/models/coupled_hierarchical_transformer/train_config_local.json @@ -14,7 +14,7 @@ "train_batch_size": 1, "eval_batch_size": 1, "learning_rate": 5e-5, - "num_train_epochs": 30.0, + "num_train_epochs": 1, "warmup_proportion": 0.1, "no_cuda": false, "local_rank":-1,