Skip to content

Commit

Permalink
[#42] Refactor modeling, training, and preprocess
Browse files Browse the repository at this point in the history
  • Loading branch information
atenzer committed Feb 11, 2022
1 parent 9bec9d8 commit ee020a8
Show file tree
Hide file tree
Showing 7 changed files with 568 additions and 1,262 deletions.
38 changes: 25 additions & 13 deletions demo_api/coupled_hierarchical_transformer/usage.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import torch
from transformers import BertConfig
from transformers import BertConfig, BertTokenizer

from sgnlp.models.coupled_hierarchical_transformer import (
DualBert,
DualBertPreprocessor,
DualBertConfig,
prepare_data_for_training
DualBertPreprocessor,
DualBertPostprocessor
)

# model_state_dict = torch.load("/Users/nus/Documents/Code/projects/SGnlp/sgnlp/output/pytorch_model.bin")
Expand All @@ -20,22 +20,34 @@
# )
#
# print("x")
from sgnlp.models.coupled_hierarchical_transformer.train import InputExample

preprocessor = DualBertPreprocessor()
config = DualBertConfig.from_pretrained("https://storage.googleapis.com/sgnlp/models/dual_bert/config.json")

config = DualBertConfig.from_pretrained("/Users/nus/Documents/Code/projects/SGnlp/sgnlp/output/config.json")
model = DualBert.from_pretrained("/Users/nus/Documents/Code/projects/SGnlp/sgnlp/output/pytorch_model.bin", config=config)
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
preprocessor = DualBertPreprocessor(config, tokenizer)
model = DualBert.from_pretrained("https://storage.googleapis.com/sgnlp/models/dual_bert/pytorch_model.bin",
config=config)
postprocessor = DualBertPostprocessor()

model.eval()

example = [
"Claim",
"Response 1",
"Response 2"
# example = [
# "#4U9525: Robin names Andreas Lubitz as the copilot in the flight deck who crashed the aircraft.",
# "@thatjohn @mschenk",
# "@thatjohn Have they named the pilot?",
# ]

examples = [
InputExample(text=[
"#4U9525: Robin names Andreas Lubitz as the copilot in the flight deck who crashed the aircraft.",
"@thatjohn @mschenk",
"@thatjohn Have they named the pilot?",
])
]

model_inputs = preprocessor([example])
model_inputs = preprocessor(examples)
# { model_param_1: ..., model_param2: ..., ...}

model(**model_inputs)

model_output = model(**model_inputs)
output = postprocessor(model_output, model_inputs["stance_label_mask"])
3 changes: 2 additions & 1 deletion sgnlp/models/coupled_hierarchical_transformer/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .config import DualBertConfig
from .modeling import DualBert
from .preprocess import prepare_data_for_training
from .preprocess import DualBertPreprocessor
from .postprocess import DualBertPostprocessor
Loading

0 comments on commit ee020a8

Please sign in to comment.