Skip to content

Commit

Permalink
[#42] postprocssor refactored to multiple inputs (usage completed)
Browse files Browse the repository at this point in the history
  • Loading branch information
atenzer committed Feb 15, 2022
1 parent 0d1293a commit dd9a41a
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 10 deletions.
6 changes: 6 additions & 0 deletions demo_api/coupled_hierarchical_transformer/usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@
# ]

examples = [
[
"#4U9525: Robin names Andreas Lubitz as the copilot in the flight deck who crashed the aircraft.",
"@thatjohn @mschenk",
"@thatjohn Have they named the pilot?",
],
[
"#4U9525: Robin names Andreas Lubitz as the copilot in the flight deck who crashed the aircraft.",
"@thatjohn @mschenk",
Expand All @@ -51,3 +56,4 @@

model_output = model(**model_inputs)
output = postprocessor(model_output, model_inputs["stance_label_mask"])
print(output)
20 changes: 10 additions & 10 deletions sgnlp/models/coupled_hierarchical_transformer/postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@ def __init__(self, rumour_labels=["FR", "TR", "UR"], stance_labels=["PAD", "B-DE
self.rumor_labels = rumour_labels
self.stance_labels = stance_labels

def __call__(self, model_output: DualBertModelOutput, stance_label_mask):
rumour_label_idx = np.argmax(model_output.rumour_logits.detach().cpu().numpy())
rumour_label = self.rumor_labels[rumour_label_idx]
def __call__(self, model_outputs: [DualBertModelOutput], stance_label_mask):

stance_label_idx = torch.argmax(F.log_softmax(model_output.stance_logits, dim=2), dim=2)
# stance_label = self.stance_labels[stance_label_idx]
rumor_labels = []
for rumor_logits in model_outputs.rumour_logits:
rumour_label_idx = np.argmax(rumor_logits.detach().cpu().numpy())
rumor_labels.append(self.rumor_labels[rumour_label_idx])

stance_preds = []
stance_labels = []
stance_label_idx = torch.argmax(F.log_softmax(model_outputs.stance_logits, dim=2), dim=2)
stance_label_mask = stance_label_mask.to("cpu").numpy()
for i, mask in enumerate(stance_label_mask):
temp_2 = []
Expand All @@ -26,10 +27,9 @@ def __call__(self, model_output: DualBertModelOutput, stance_label_mask):
temp_2.append(self.stance_labels[stance_label_idx[i][j]])
else:
break
stance_preds.append(temp_2)

stance_labels.append(temp_2)

return {
"rumour_label": rumour_label,
"stance_label": stance_preds
"rumor_labels": rumor_labels,
"stance_labels": stance_labels
}

0 comments on commit dd9a41a

Please sign in to comment.