Skip to content

Commit

Permalink
[#42] Update device uploads and previous optimiser for testing
Browse files Browse the repository at this point in the history
  • Loading branch information
atenzer committed Mar 7, 2022
1 parent 20dfa05 commit 109dd40
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 22 deletions.
32 changes: 16 additions & 16 deletions sgnlp/models/dual_bert/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,29 +379,29 @@ def init_bert(self):
self.bert = BertModel.from_pretrained("bert-base-uncased")

def forward(self, input_ids_buckets, segment_ids_buckets, input_mask_buckets, input_mask, stance_position,
stance_label_mask, stance_label_ids=None, rumor_label_ids=None, rumor_labels=None, stance_labels=None):
stance_label_mask, stance_label_ids=None, rumor_label_ids=None):
# def forward(self, input_ids1, token_type_ids1, attention_mask1, input_ids2, token_type_ids2, attention_mask2,
# input_ids3, token_type_ids3, attention_mask3, input_ids4, token_type_ids4, attention_mask4,
# attention_mask, rumor_labels=None, stance_labels=None, stance_label_mask=None):
# attention_mask, rumor_labels=None, stance_label_ids=None, stance_label_mask=None):

output = DualBertModelOutput()

output_sequence = torch.tensor([], dtype=torch.int32)
output_sequence = torch.tensor([], device=self.device, dtype=torch.int32)
# for input_ids, token_type_ids, attention_mask in zip(processed_input["input_ids_buckets"], processed_input[
# "segment_ids_buckets"], processed_input["input_mask_buckets"]):
num_buckets = 4
for i in range(num_buckets):
input_ids = input_ids_buckets[:, i]
token_type_ids = segment_ids_buckets[:, i]
attention_mask = input_mask_buckets[:, i]
input_ids = input_ids_buckets[:, i].to(self.device)
token_type_ids = segment_ids_buckets[:, i].to(self.device)
attention_mask = input_mask_buckets[:, i].to(self.device)

tmp_model_output = self.bert(input_ids, token_type_ids, attention_mask, output_hidden_states=False)
output_sequence = torch.cat((output_sequence, tmp_model_output.last_hidden_state), dim=1)

extended_attention_mask = input_mask.unsqueeze(1).unsqueeze(2)
extended_attention_mask = extended_attention_mask.to(
dtype=next(self.parameters()).dtype) # fp16 compatibility
extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
extended_attention_mask = extended_attention_mask.to(self.device)

# for stance classification task
# '''
Expand All @@ -423,9 +423,9 @@ def forward(self, input_ids_buckets, segment_ids_buckets, input_mask_buckets, in

# ''' add label attention layer to incorporate stance predictions for rumor verification
extended_label_mask = stance_label_mask.unsqueeze(1).unsqueeze(2)
extended_label_mask = extended_label_mask.to(
dtype=next(self.parameters()).dtype) # fp16 compatibility
extended_label_mask = extended_label_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility
extended_label_mask = (1.0 - extended_label_mask) * -10000.0
extended_label_mask = extended_label_mask.to(self.device)

rumor_output = self.rumor_pooler(add_rumor_bert_text_output_layer)
tweet_level_output = self.stance_pooler(add_rumor_bert_text_output_layer, self.max_tweet_num,
Expand All @@ -439,26 +439,26 @@ def forward(self, input_ids_buckets, segment_ids_buckets, input_mask_buckets, in
output.rumour_logits = self.rumor_classifier(rumor_pooled_output)
# '''

if rumor_labels is not None:
if rumor_label_ids is not None:
# alpha = 0.1
loss_fct = CrossEntropyLoss()
output.rumour_loss = loss_fct(output.rumour_logits.view(-1, self.rumor_num_labels), rumor_labels.view(-1))
output.rumour_loss = loss_fct(output.rumour_logits.view(-1, self.rumor_num_labels).to(self.device), rumor_label_ids.view(-1).to(self.device))
output.attention_probs = attention_probs[:, 0, 0, :]
# sim_loss = self.cos_sim(stance_attention, rumor_attention)
# return loss + alpha*sim_loss

if stance_labels is not None: # for stance classification task
if stance_label_ids is not None: # for stance classification task
loss_fct = CrossEntropyLoss()
# Only keep active parts of the loss
if stance_label_mask is not None:
active_loss = stance_label_mask.view(-1) == 1
# print(active_loss)
# print(logits)
active_logits = output.stance_logits.view(-1, self.stance_num_labels)[active_loss]
active_labels = stance_labels.view(-1)[active_loss]
active_logits = output.stance_logits.view(-1, self.stance_num_labels)[active_loss].to(self.device)
active_labels = stance_label_ids.view(-1)[active_loss].to(self.device)
loss = loss_fct(active_logits, active_labels)
else:
loss = loss_fct(output.stance_logits.view(-1, self.stance_num_labels), stance_labels.view(-1))
loss = loss_fct(output.stance_logits.view(-1, self.stance_num_labels).to(self.device), stance_label_ids.view(-1).to(self.device))
output.stance_loss = loss

return output
20 changes: 15 additions & 5 deletions sgnlp/models/dual_bert/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

from .modeling import DualBert
from transformers import AdamW
from .optimization import BertAdam
from .preprocess import prepare_data_for_training, InputExample, DualBertPreprocessor
from sklearn.metrics import precision_recall_fscore_support

Expand Down Expand Up @@ -360,6 +361,15 @@ def train_custom_dual_bert(train_config: CustomDualBertTrainConfig, model_config
optimizer_grouped_parameters,
lr=train_config.learning_rate,
)
# For testing with original optimizer
# optimizer = BertAdam(optimizer_grouped_parameters,
# lr=train_config.learning_rate,
# warmup=0.1,
# t_total=t_total)
# stance_optimizer = BertAdam(optimizer_grouped_parameters,
# lr=train_config.learning_rate,
# warmup=0.1,
# t_total=t_total)

global_step = 0
nb_tr_steps = 0
Expand Down Expand Up @@ -453,7 +463,7 @@ def train_custom_dual_bert(train_config: CustomDualBertTrainConfig, model_config
# ) = batch

# optimize rumor detection task
tmp_model_output = model(**rumor_batch, rumor_labels=rumor_batch["rumor_label_ids"])
tmp_model_output = model(**rumor_batch)
# tmp_model_output = model(
# input_ids1,
# segment_ids1,
Expand Down Expand Up @@ -499,7 +509,7 @@ def train_custom_dual_bert(train_config: CustomDualBertTrainConfig, model_config
global_step += 1

# optimize stance classification task
tmp_model_output = model(**stance_batch, stance_labels=stance_batch["stance_label_ids"])
tmp_model_output = model(**stance_batch)
# tmp_model_output = model(
# stance_input_ids1,
# stance_segment_ids1,
Expand Down Expand Up @@ -630,7 +640,7 @@ def train_custom_dual_bert(train_config: CustomDualBertTrainConfig, model_config
desc="Evaluating"):
with torch.no_grad():
rumor_batch = {k: v.to(device) for k, v in rumor_batch.items()}
tmp_model_output = model(**rumor_batch, rumor_labels=rumor_batch["rumor_label_ids"])
tmp_model_output = model(**rumor_batch)
tmp_eval_loss = tmp_model_output.rumour_loss
logits = tmp_model_output.rumour_logits
stance_logits = tmp_model_output.stance_logits
Expand Down Expand Up @@ -799,7 +809,7 @@ def train_custom_dual_bert(train_config: CustomDualBertTrainConfig, model_config
desc="Evaluating"):

with torch.no_grad():
tmp_model_output = model(**rumor_batch, rumor_labels=rumor_batch["rumor_label_ids"])
tmp_model_output = model(**rumor_batch)
# tmp_model_output = model(
# input_ids1,
# segment_ids1,
Expand Down Expand Up @@ -1054,7 +1064,7 @@ def train_custom_dual_bert(train_config: CustomDualBertTrainConfig, model_config
# label_mask = label_mask.to(device)

with torch.no_grad():
tmp_model_output = model(**batch, rumor_labels=batch["rumor_label_ids"])
tmp_model_output = model(**batch)
# tmp_model_output = model(
# input_ids1,
# segment_ids1,
Expand Down
2 changes: 1 addition & 1 deletion sgnlp/models/dual_bert/train_config.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"rumor_num_labels":3,
"stance_num_labels":5,
"task_name2": "semeval17_stance",
"output_dir": "/polyaxon-data/workspace/atenzer/CHT_demo/output_release/semeval17_multitask_output_DB_28_02/",
"output_dir": "/polyaxon-data/workspace/atenzer/CHT_demo/output_release_01_03/semeval17_multitask_output_DB/",
"max_seq_length": 512,
"do_train":true,
"do_eval": true,
Expand Down

0 comments on commit 109dd40

Please sign in to comment.