diff --git a/demo_api/dual_bert/usage.py b/demo_api/dual_bert/usage.py index 56a71c6..8703e8b 100644 --- a/demo_api/dual_bert/usage.py +++ b/demo_api/dual_bert/usage.py @@ -17,11 +17,6 @@ model.eval() examples = [ - [ - "#4U9525: Robin names Andreas Lubitz as the copilot in the flight deck who crashed the aircraft.", - "@thatjohn @mschenk", - "@thatjohn Have they named the pilot?", - ], [ "#4U9525: Robin names Andreas Lubitz as the copilot in the flight deck who crashed the aircraft.", "@thatjohn @mschenk", diff --git a/sgnlp/models/dual_bert/postprocess.py b/sgnlp/models/dual_bert/postprocess.py index 264a459..c420ed1 100644 --- a/sgnlp/models/dual_bert/postprocess.py +++ b/sgnlp/models/dual_bert/postprocess.py @@ -27,7 +27,7 @@ def __call__(self, model_outputs: [DualBertModelOutput], stance_label_mask): temp_2.append(self.stance_labels[stance_label_idx[i][j]]) else: break - stance_labels.append(temp_2) + stance_labels.append(temp_2[1:]) #first post should not have a stance label return { "rumor_labels": rumor_labels,