diff --git a/demo_api/dual_bert/api.py b/demo_api/dual_bert/api.py index 67ea036..2a0d1be 100644 --- a/demo_api/dual_bert/api.py +++ b/demo_api/dual_bert/api.py @@ -36,7 +36,7 @@ def predict(): outputs = postprocessor(model_output, model_inputs["stance_label_mask"]) - return {"rumor_label": outputs["rumor_labels"][0], + return {"rumor_labels": outputs["rumor_labels"][0], "stance_labels": outputs["stance_labels"][0], "text": examples[0]} diff --git a/sgnlp/models/dual_bert/postprocess.py b/sgnlp/models/dual_bert/postprocess.py index c8a4033..8cb213c 100644 --- a/sgnlp/models/dual_bert/postprocess.py +++ b/sgnlp/models/dual_bert/postprocess.py @@ -6,17 +6,17 @@ class DualBertPostprocessor: - def __init__(self, rumour_labels=["False Rumor", "True Rumor", "Unverified Rumor"], - stance_labels=["PAD", "Deny", "Support", "Query", "Comment"]): - self.rumor_labels = rumour_labels - self.stance_labels = stance_labels + def __init__(self, rumour_labels_list=["False Rumor", "True Rumor", "Unverified Rumor"], + stance_labels_list=["PAD", "Deny", "Support", "Query", "Comment"]): + self.rumor_labels_list = rumour_labels_list + self.stance_labels_list = stance_labels_list def __call__(self, model_outputs: [DualBertModelOutput], stance_label_mask): rumor_labels = [] for rumor_logits in model_outputs.rumour_logits: rumour_label_idx = np.argmax(rumor_logits.detach().cpu().numpy()) - rumor_labels.append(self.rumor_labels[rumour_label_idx]) + rumor_labels.append(self.rumor_labels_list[rumour_label_idx]) stance_labels = [] stance_label_idx = torch.argmax(F.log_softmax(model_outputs.stance_logits, dim=2), dim=2) @@ -25,7 +25,7 @@ def __call__(self, model_outputs: [DualBertModelOutput], stance_label_mask): temp_2 = [] for j, m in enumerate(mask): if m: - temp_2.append(self.stance_labels[stance_label_idx[i][j]]) + temp_2.append(self.stance_labels_list[stance_label_idx[i][j]]) else: break stance_labels.append(temp_2[1:]) # first post should not have a stance label