diff --git a/demo_api/README.md b/demo_api/README.md index 8e7a126..10f64ad 100644 --- a/demo_api/README.md +++ b/demo_api/README.md @@ -2,12 +2,13 @@ ``` # From root folder of repository: docker build -t -f demo_api//Dockerfile demo_api/ - docker run -p 8000:8000 -E.g. +# Example: Production build docker build -t lsr -f demo_api/lsr/Dockerfile demo_api/ -docker run -p 8000:8000 lsr + +# Example: Dev build +docker build -t coherence_momentum -f demo_api/coherence_momentum/dev.Dockerfile . ``` ## Notes on dev vs prod build diff --git a/demo_api/coherence_momentum/api.py b/demo_api/coherence_momentum/api.py new file mode 100644 index 0000000..b510ca1 --- /dev/null +++ b/demo_api/coherence_momentum/api.py @@ -0,0 +1,46 @@ +from flask import request + +from demo_api.common import create_api +from sgnlp.models.coherence_momentum import ( + CoherenceMomentumModel, + CoherenceMomentumConfig, + CoherenceMomentumPreprocessor +) + +app = create_api(app_name=__name__, model_card_path="model_card/coherence_momentum.json") + +# Load processors and models +config = CoherenceMomentumConfig.from_pretrained( + "https://storage.googleapis.com/sgnlp/models/coherence_momentum/config.json" +) +model = CoherenceMomentumModel.from_pretrained( + "https://storage.googleapis.com/sgnlp/models/coherence_momentum/pytorch_model.bin", + config=config +) + +preprocessor = CoherenceMomentumPreprocessor(config.model_size, config.max_len) + +app.logger.info("Model initialization complete") + + +@app.route("/predict", methods=["POST"]) +def predict(): + req_body = request.get_json() + + text1 = req_body["text1"] + text2 = req_body["text2"] + + text1_tensor = preprocessor([text1]) + text2_tensor = preprocessor([text2]) + + text1_score = model.get_main_score(text1_tensor["tokenized_texts"]).item() + text2_score = model.get_main_score(text2_tensor["tokenized_texts"]).item() + + return { + "text1_score": text1_score, + "text2_score": text2_score + } + + +if __name__ == "__main__": + app.run() diff --git a/demo_api/coherence_momentum/dev.Dockerfile b/demo_api/coherence_momentum/dev.Dockerfile new file mode 100644 index 0000000..2b58fe8 --- /dev/null +++ b/demo_api/coherence_momentum/dev.Dockerfile @@ -0,0 +1,14 @@ +FROM python:3.8-buster + +COPY ./demo_api /demo_api +COPY ./sgnlp /sgnlp +COPY ./setup.py /setup.py +COPY ./README.md /README.md + +RUN pip install -r /demo_api/coherence_momentum/requirements_dev.txt + +WORKDIR /demo_api/coherence_momentum + +RUN python -m download_pretrained + +CMD PYTHONPATH=../../ gunicorn -c ../gunicorn.conf.py \ No newline at end of file diff --git a/demo_api/coherence_momentum/download_pretrained.py b/demo_api/coherence_momentum/download_pretrained.py new file mode 100644 index 0000000..94532d9 --- /dev/null +++ b/demo_api/coherence_momentum/download_pretrained.py @@ -0,0 +1,9 @@ +from sgnlp.models.coherence_momentum import CoherenceMomentumModel, CoherenceMomentumConfig + +config = CoherenceMomentumConfig.from_pretrained( + "https://storage.googleapis.com/sgnlp/models/coherence_momentum/config.json" +) +model = CoherenceMomentumModel.from_pretrained( + "https://storage.googleapis.com/sgnlp/models/coherence_momentum/pytorch_model.bin", + config=config +) diff --git a/demo_api/coherence_momentum/model_card/coherence_momentum.json b/demo_api/coherence_momentum/model_card/coherence_momentum.json new file mode 100644 index 0000000..73bae60 --- /dev/null +++ b/demo_api/coherence_momentum/model_card/coherence_momentum.json @@ -0,0 +1,36 @@ +{ + "name": "CoherenceMomentum", + "languages": "English", + "description": "This is a neural network model that makes use of a momentum encoder and hard negative mining during training. This model is able to take in a piece of text and output a coherence score. The coherence score is only meant for comparison, i.e. it is only meaningful when used to compare between two texts, and the text with the higher coherence score is deemed to be more coherent by the model.", + "paper": { + "text": "Jwalapuram, P., Joty, S., & Lin, X. (2022). Rethinking Self-Supervision Objectives for Generalizable Coherence Modeling. Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), May 2022 (pp. 6044-6059).", + "url": "https://aclanthology.org/2022.acl-long.418/" + }, + "trainingDataset": "Permuted dataset derived from Linguistic Data Consortium's (LDC) Wall Street Journal (WSJ) dataset. Please contact the authors to get the dataset if you have a valid LDC license.", + "evaluationDataset": "Permuted dataset derived from Linguistic Data Consortium's (LDC) Wall Street Journal (WSJ) dataset. Please contact the authors to get the dataset if you have a valid LDC license.", + "evaluationScores": "0.988 accuracy on permuted WSJ dataset. 0.986 accuracy reported by authors on permuted WSJ dataset.", + "trainingConfig": { + "text": "https://storage.googleapis.com/sgnlp/models/coherence_momentum/config.json", + "url": "https://storage.googleapis.com/sgnlp/models/coherence_momentum/config.json" + }, + "trainingTime": "~24 hours for ~46000 steps (batch size of 1) on a single A100 GPU", + "modelWeights": { + "text": "https://storage.googleapis.com/sgnlp/models/coherence_momentum/pytorch_model.bin", + "url": "https://storage.googleapis.com/sgnlp/models/coherence_momentum/pytorch_model.bin" + }, + "modelInput": "A paragraph of text. During training, each positive example can be paired with one or more negative examples.", + "modelOutput": "Coherence score for the input text.", + "modelSize": "~930MB", + "inferenceInfo": "Not available.", + "usageScenarios": "Essay scoring, summarization, language generation.", + "originalCode": { + "text": "https://github.com/ntunlp/coherence-paradigm", + "url": "https://github.com/ntunlp/coherence-paradigm" + }, + "license": { + "text": "MIT License", + "url": "https://choosealicense.com/licenses/mit" + }, + "contact": "sg-nlp@aisingapore.org", + "additionalInfo": "Not applicable." +} \ No newline at end of file diff --git a/demo_api/coherence_momentum/requirements_dev.txt b/demo_api/coherence_momentum/requirements_dev.txt new file mode 100644 index 0000000..a32201d --- /dev/null +++ b/demo_api/coherence_momentum/requirements_dev.txt @@ -0,0 +1,3 @@ +-e. +flask +gunicorn \ No newline at end of file diff --git a/demo_api/coherence_momentum/usage.py b/demo_api/coherence_momentum/usage.py new file mode 100644 index 0000000..63ef79a --- /dev/null +++ b/demo_api/coherence_momentum/usage.py @@ -0,0 +1,33 @@ +from sgnlp.models.coherence_momentum import CoherenceMomentumModel, CoherenceMomentumConfig, \ + CoherenceMomentumPreprocessor + +config = CoherenceMomentumConfig.from_pretrained( + "https://storage.googleapis.com/sgnlp/models/coherence_momentum/config.json" +) +model = CoherenceMomentumModel.from_pretrained( + "https://storage.googleapis.com/sgnlp/models/coherence_momentum/pytorch_model.bin", + config=config +) + +preprocessor = CoherenceMomentumPreprocessor(config.model_size, config.max_len) + +text1 = "Companies listed below reported quarterly profit substantially different from the average of analysts ' " \ + "estimates . The companies are followed by at least three analysts , and had a minimum five-cent change in " \ + "actual earnings per share . Estimated and actual results involving losses are omitted . The percent " \ + "difference compares actual profit with the 30-day estimate where at least three analysts have issues " \ + "forecasts in the past 30 days . Otherwise , actual profit is compared with the 300-day estimate . " \ + "Source : Zacks Investment Research" +text2 = "The companies are followed by at least three analysts , and had a minimum five-cent change in actual " \ + "earnings per share . The percent difference compares actual profit with the 30-day estimate where at least " \ + "three analysts have issues forecasts in the past 30 days . Otherwise , actual profit is compared with the " \ + "300-day estimate . Source : Zacks Investment Research. Companies listed below reported quarterly profit " \ + "substantially different from the average of analysts ' estimates . Estimated and actual results involving " \ + "losses are omitted ." + +text1_tensor = preprocessor([text1]) +text2_tensor = preprocessor([text2]) + +text1_score = model.get_main_score(text1_tensor["tokenized_texts"]).item() +text2_score = model.get_main_score(text2_tensor["tokenized_texts"]).item() + +print(text1_score, text2_score) diff --git a/demo_api/lif_3way_ap/model_card/lif_3way_ap.json b/demo_api/lif_3way_ap/model_card/lif_3way_ap.json index a66c09d..a482fc8 100644 --- a/demo_api/lif_3way_ap/model_card/lif_3way_ap.json +++ b/demo_api/lif_3way_ap/model_card/lif_3way_ap.json @@ -16,7 +16,8 @@ }, "evaluationScores": "0.745 F1 on test_i dataset. 0.75 F1 reported by authors in paper on test_i dataset.", "trainingConfig": { - "text": "https://storage.googleapis.com/sgnlp/models/lif_3way_ap/config.json" + "text": "https://storage.googleapis.com/sgnlp/models/lif_3way_ap/config.json", + "url": "https://storage.googleapis.com/sgnlp/models/lif_3way_ap/config.json" }, "trainingTime": "~12 hours for 13 epochs on a single V100 GPU.", "modelWeights": { diff --git a/jsonnet/demo-api.jsonnet b/jsonnet/demo-api.jsonnet index 90a1185..9ee18df 100644 --- a/jsonnet/demo-api.jsonnet +++ b/jsonnet/demo-api.jsonnet @@ -16,6 +16,24 @@ local build_and_push_staging(module_name, image_name) = { ], }; +local build_and_push_dev_staging(module_name, image_name) = { + image: "registry.aisingapore.net/sg-nlp/sg-nlp-runner:latest", + stage: "build_and_push_staging", + tags: [ + "on-prem", + "dind", + ], + when: "manual", + script: [ + "echo 'Logging in to AISG Docker Registry...'", + "echo $STG_REGISTRY_PASSWORD | docker login registry.aisingapore.net -u $STG_DOCKER_USER --password-stdin", + "echo 'Building and pushing image...'", + "docker build --no-cache -t %s -f demo_api/%s/dev.Dockerfile ." % [module_name, module_name], + "docker tag %s registry.aisingapore.net/sg-nlp/%s:latest" % [module_name, image_name], + "docker push registry.aisingapore.net/sg-nlp/%s:latest" % image_name, + ], +}; + local build_and_push_docs_staging() = { image: "python:3.8.11-slim", stage: "build_and_push_staging", @@ -154,6 +172,15 @@ local api_names = { } }; +// To deploy dev builds into production (for beta public testing) +local dev_api_names = { + "coherence_momentum": { + module_name: "coherence_momentum", + image_name: "coherence-momentum", + deployment_name: "coherence-momentum" + } +}; + { "stages": [ "build_and_push_staging", @@ -166,6 +193,11 @@ local api_names = { [api_names[key]["module_name"] + "_build_and_push_staging"]: build_and_push_staging(api_names[key]["module_name"], api_names[key]["image_name"]) for key in std.objectFields(api_names) +} + { + // Build and push dev staging + [dev_api_names[key]["module_name"] + "_build_and_push_dev_staging"]: + build_and_push_dev_staging(dev_api_names[key]["module_name"], dev_api_names[key]["image_name"]) + for key in std.objectFields(dev_api_names) } + { // Restart kubernetes staging [api_names[key]["module_name"] + "_restart_kubernetes_staging"]: diff --git a/jsonnet/dev-demo-api.jsonnet b/jsonnet/dev-demo-api.jsonnet index 3239732..25ac1ba 100644 --- a/jsonnet/dev-demo-api.jsonnet +++ b/jsonnet/dev-demo-api.jsonnet @@ -59,6 +59,11 @@ local api_names = { module_name: "ufd", image_name: "ufd", deployment_name: "ufd" + }, + "coherence_momentum": { + module_name: "coherence_momentum", + image_name: "coherence-momentum", + deployment_name: "coherence-momentum" } }; diff --git a/polyaxon/coherence_momentum/model-training.Dockerfile b/polyaxon/coherence_momentum/model-training.Dockerfile new file mode 100644 index 0000000..44f96d4 --- /dev/null +++ b/polyaxon/coherence_momentum/model-training.Dockerfile @@ -0,0 +1,15 @@ +FROM pytorch/pytorch:1.11.0-cuda11.3-cudnn8-devel + +ARG REPO_DIR="." +ARG PROJECT_USER="aisg" +ARG HOME_DIR="/home/$PROJECT_USER" + +COPY $REPO_DIR nlp-hub-gcp +WORKDIR $REPO_DIR/nlp-hub-gcp + +RUN pip install -r polyaxon/coherence_momentum/requirements.txt +RUN groupadd -g 2222 $PROJECT_USER && useradd -u 2222 -g 2222 -m $PROJECT_USER +RUN chown -R 2222:2222 $HOME_DIR && \ + rm /bin/sh && ln -s /bin/bash /bin/sh +USER 2222 + diff --git a/polyaxon/coherence_momentum/polyaxon-experiment-nomig.yml b/polyaxon/coherence_momentum/polyaxon-experiment-nomig.yml new file mode 100644 index 0000000..84a62a1 --- /dev/null +++ b/polyaxon/coherence_momentum/polyaxon-experiment-nomig.yml @@ -0,0 +1,62 @@ +version: 1.1 +kind: component +name: train-model +description: Job for training a predictive model using GPU. +tags: [model_training] +inputs: + - name: SA_CRED_PATH + description: Path to credential file for GCP service account. + isOptional: true + type: str + value: /var/secret/cloud.google.com/gcp-service-account.json + toEnv: GOOGLE_APPLICATION_CREDENTIALS + - name: WORKING_DIR + description: The working directory for the job to run in. + isOptional: true + value: /home/aisg/nlp-hub-gcp + type: str + - name: TRAIN_CONFIG_FILE_PATH + description: Config file path. + type: str + isOptional: false + - name: MODEL_CONFIG_FILE_PATH + description: Config file path. + type: str + isOptional: false +run: + kind: job + connections: [fstore-pvc] + environment: + imagePullSecrets: ["gcp-imagepullsecrets"] + tolerations: + - effect: NoSchedule + key: nvidia.com/gpu + operator: Equal + value: present + - effect: NoSchedule + key: nomig + operator: Equal + value: present + volumes: + - name: gcp-service-account + secret: + secretName: "gcp-sa-credentials" + container: + image: asia.gcr.io/nlp-hub/coherence-paradigm-refactored:0.6 + imagePullPolicy: IfNotPresent + workingDir: "{{ WORKING_DIR }}" + command: ["/bin/bash","-c"] + args: [ + "python -m sgnlp.models.coherence_momentum.train \ + --train_config_file {{ TRAIN_CONFIG_FILE_PATH }} \ + --model_config_file {{ MODEL_CONFIG_FILE_PATH }} + " + ] + resources: + requests: + nvidia.com/gpu: 1 + limits: + nvidia.com/gpu: 1 + volumeMounts: + - name: gcp-service-account + mountPath: /var/secret/cloud.google.com diff --git a/polyaxon/coherence_momentum/requirements.txt b/polyaxon/coherence_momentum/requirements.txt new file mode 100644 index 0000000..13ca7f6 --- /dev/null +++ b/polyaxon/coherence_momentum/requirements.txt @@ -0,0 +1,21 @@ +pandas==1.1.5 +mlflow==1.22.0 +protobuf==3.20.* +pylint==2.6.0 +pytest-cov==2.10.1 +pyyaml==5.4.1 +python-json-logger==2.0.2 +polyaxon==1.11.3 +google-cloud-storage==1.43.0 +hydra-core==1.1.1 +hydra-optuna-sweeper==1.1.1 +optuna==2.10.0 +fastapi==0.70.1 +uvicorn[standard]==0.14.0 +gunicorn==20.1.0 +nltk +scikit-learn +torchtext +transformers +sentencepiece +-e . \ No newline at end of file diff --git a/sgnlp/models/coherence_momentum/__init__.py b/sgnlp/models/coherence_momentum/__init__.py new file mode 100644 index 0000000..9599aa0 --- /dev/null +++ b/sgnlp/models/coherence_momentum/__init__.py @@ -0,0 +1,3 @@ +from .modeling import CoherenceMomentumModel +from .config import CoherenceMomentumConfig +from .preprocess import CoherenceMomentumPreprocessor diff --git a/sgnlp/models/coherence_momentum/config.py b/sgnlp/models/coherence_momentum/config.py new file mode 100755 index 0000000..06f63d4 --- /dev/null +++ b/sgnlp/models/coherence_momentum/config.py @@ -0,0 +1,26 @@ +from transformers import PretrainedConfig + + +class CoherenceMomentumConfig(PretrainedConfig): + def __init__( + self, + model_size: str = "base", + margin: float = 0.1, + num_negs: int = 5, + max_len: int = 600, + num_rank_negs: int = 50, + momentum_coefficient: float = 0.9999999, + queue_size: int = 1000, + contrastive_loss_weight: float = 0.85, + **kwargs + ): + super().__init__(**kwargs) + + self.model_size = model_size + self.margin = margin + self.num_negs = num_negs + self.max_len = max_len + self.num_rank_negs = num_rank_negs + self.momentum_coefficient = momentum_coefficient + self.queue_size = queue_size + self.contrastive_loss_weight = contrastive_loss_weight diff --git a/sgnlp/models/coherence_momentum/model_config.json b/sgnlp/models/coherence_momentum/model_config.json new file mode 100644 index 0000000..0b44fbe --- /dev/null +++ b/sgnlp/models/coherence_momentum/model_config.json @@ -0,0 +1,10 @@ +{ + "contrastive_loss_weight": 0.85, + "margin": 0.1, + "max_len": 600, + "model_size": "base", + "momentum_coefficient": 0.9999999, + "num_negs": 5, + "num_rank_negs": 50, + "queue_size": 1000 +} \ No newline at end of file diff --git a/sgnlp/models/coherence_momentum/modeling.py b/sgnlp/models/coherence_momentum/modeling.py new file mode 100644 index 0000000..575d146 --- /dev/null +++ b/sgnlp/models/coherence_momentum/modeling.py @@ -0,0 +1,108 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import XLNetModel, XLNetConfig +from transformers import PreTrainedModel +from .config import CoherenceMomentumConfig + + +class CoherenceMomentumPreTrainedModel(PreTrainedModel): + config_class = CoherenceMomentumConfig + base_model_prefix = "coherence_momentum" + + +class CoherenceMomentumModel(CoherenceMomentumPreTrainedModel): + def __init__(self, config): + + super().__init__(config) + self.momentum_coefficient = config.momentum_coefficient + + self.encoder_name = f"xlnet-{config.model_size}-cased" + self.encoder_config = XLNetConfig.from_pretrained(self.encoder_name) + self.main_encoder = XLNetModel(self.encoder_config) + self.momentum_encoder = XLNetModel(self.encoder_config) + + if config.model_size == "base": + hidden_size = 768 + elif config.model_size == "large": + hidden_size = 1024 + + self.queue = [] + self.queue_size = config.queue_size + self.con_loss_weight = config.contrastive_loss_weight + self.num_negs = config.num_negs + self.margin = config.margin + self.cosim = nn.CosineSimilarity() + self.sub_margin = lambda z: z - config.margin + + self.conlinear = nn.Linear(hidden_size, 1) + + def init_encoders(self): + self.main_encoder = XLNetModel.from_pretrained(self.encoder_name) + self.momentum_encoder = XLNetModel.from_pretrained(self.encoder_name) + + def get_main_score(self, doc): + rep = self.main_encoder(input_ids=doc).last_hidden_state[:, -1, :] + score = self.conlinear(rep).view(-1) + return score + + def get_momentum_rep(self, doc): + rep = self.momentum_encoder(input_ids=doc).last_hidden_state[:, -1, :] + return rep.detach() + + def get_cos_sim(self, pos_rep, pos_slice): + pos_sim = self.cosim(pos_rep, pos_slice) + neg_sims = [self.cosim(pos_rep, neg_x.view(1, -1)) for neg_x in self.queue] + return pos_sim, neg_sims + + def update_momentum_encoder(self): + with torch.no_grad(): + for main, moco in zip( + self.main_encoder.parameters(), self.momentum_encoder.parameters() + ): + moco.data = (moco.data * self.momentum_coefficient) + ( + main.data * (1 - self.momentum_coefficient) + ) + + def forward(self, pos_doc, pos_slice, neg_docs): + pos_rep = self.main_encoder(input_ids=pos_doc).last_hidden_state[:, -1, :] + pos_score = self.conlinear(pos_rep).view(-1) + + pos_slice_rep = self.get_momentum_rep(pos_slice) + + neg_scores = list(map(self.get_main_score, list(neg_docs))) + neg_moco_rep = list(map(self.get_momentum_rep, list(neg_docs))) + + if len(self.queue) >= self.queue_size: # global negative queue size + del self.queue[: self.num_negs] + self.queue.extend(neg_moco_rep[0]) + + pos_sim, neg_sims = self.get_cos_sim(pos_rep, pos_slice_rep) + + sim_contra_loss = self.sim_contrastive_loss(pos_sim, neg_sims) + contra_loss = self.contrastive_loss(pos_score, neg_scores[0]) + + full_loss = (self.con_loss_weight * contra_loss) + ( + (1 - self.con_loss_weight) * sim_contra_loss + ) + + return full_loss + + def eval_forward(self, pos_doc, neg_docs): + pos_score = self.get_main_score(pos_doc) + neg_scores = torch.stack(list(map(self.get_main_score, list(neg_docs)))) + return pos_score.detach(), neg_scores[0].detach() + + def sim_contrastive_loss(self, pos_sim, neg_sims): + neg_sims_sub = torch.stack(list(map(self.sub_margin, neg_sims))).view(-1) + all_sims = torch.cat((neg_sims_sub, pos_sim), dim=-1) + lsmax = -1 * F.log_softmax(all_sims, dim=-1) + loss = lsmax[-1] + return loss + + def contrastive_loss(self, pos_score, neg_scores): + neg_scores_sub = torch.stack(list(map(self.sub_margin, neg_scores))) + all_scores = torch.cat((neg_scores_sub, pos_score), dim=-1) + lsmax = -1 * F.log_softmax(all_scores, dim=-1) + pos_loss = lsmax[-1] + return pos_loss diff --git a/sgnlp/models/coherence_momentum/preprocess.py b/sgnlp/models/coherence_momentum/preprocess.py new file mode 100644 index 0000000..8d6acb5 --- /dev/null +++ b/sgnlp/models/coherence_momentum/preprocess.py @@ -0,0 +1,48 @@ +from typing import List + +import torch +from transformers import XLNetTokenizer + + +class CoherenceMomentumPreprocessor: + def __init__(self, model_size, max_len, tokenizer=None): + if tokenizer is not None: + self.tokenizer = tokenizer + else: + self.tokenizer = XLNetTokenizer.from_pretrained(f"xlnet-{model_size}-cased") + + self.max_len = max_len + + def __call__(self, texts: List[str]): + """ + + Args: + texts (List[str]): List of input texts + + Returns: + Dict[str, str]: Returns a dictionary with the following key-values: + "tokenized_texts": (torch.tensor) Tensors of tokenized ids of input texts + """ + + result = [] + for text in texts: + tokens = self.tokenizer.tokenize(text) + ids = self.tokenizer.convert_tokens_to_ids(tokens) + ids = self.pad_ids(ids) + ids = self.tokenizer.build_inputs_with_special_tokens(ids) + result.append(torch.tensor(ids)) + + return {"tokenized_texts": torch.stack(result)} + + def pad_ids(self, ids): + if len(ids) < self.max_len: + padding_size = self.max_len - len(ids) + padding = [ + self.tokenizer.convert_tokens_to_ids(self.tokenizer.pad_token) + for _ in range(padding_size) + ] + ids = ids + padding + else: + ids = ids[: self.max_len] + + return ids diff --git a/sgnlp/models/coherence_momentum/train.py b/sgnlp/models/coherence_momentum/train.py new file mode 100755 index 0000000..ba3284e --- /dev/null +++ b/sgnlp/models/coherence_momentum/train.py @@ -0,0 +1,402 @@ +import argparse +import pickle +import time +import os +import datetime +import random +import shutil + +import torch +from torch.utils.data import Dataset, DataLoader, SequentialSampler +from transformers import AdamW, XLNetTokenizer +from torch.optim.swa_utils import SWALR + +from .modeling import CoherenceMomentumModel +from .config import CoherenceMomentumConfig +from .train_config import CoherenceMomentumTrainConfig +from sgnlp.utils.train_config import load_train_config + + +class CoherenceMomentumDataset(Dataset): + def __init__(self, fname, model_size, device, datatype, negs, max_len): + self.fname = fname + self.device = device + self.data = pickle.load(open(fname, "rb")) + random.shuffle(self.data) + self.tokenizer = XLNetTokenizer.from_pretrained( + "xlnet-{}-cased".format(model_size) + ) + self.datatype = datatype + self.negs = negs + self.max_len = max_len + + def pad_ids(self, ids): + if len(ids) < self.max_len: + padding_size = self.max_len - len(ids) + padding = [ + self.tokenizer.convert_tokens_to_ids(self.tokenizer.pad_token) + for _ in range(padding_size) + ] + ids = ids + padding + else: + ids = ids[: self.max_len] + + return ids + + def prepare_data(self, idx): + pos_doc = self.data[idx]["pos"] + + if self.datatype == "single": + neg_docs = [self.data[idx]["neg"]] + elif self.datatype == "multiple": + neg_docs = self.data[idx]["negs"][: self.negs] + else: + raise Exception("Unexpected datatype") + + pos_span = pos_doc + pos_span = " ".join(pos_span) + pos_tokens = self.tokenizer.tokenize(pos_span) + pos_ids = self.tokenizer.convert_tokens_to_ids(pos_tokens) + pos_ids = self.pad_ids(pos_ids) + + neg_span_list = [] + for neg_doc in neg_docs: + neg_span = neg_doc + neg_span = " ".join(neg_span) + neg_tokens = self.tokenizer.tokenize(neg_span) + neg_ids = self.tokenizer.convert_tokens_to_ids(neg_tokens) + neg_ids = self.pad_ids(neg_ids) + neg_input = self.tokenizer.build_inputs_with_special_tokens(neg_ids) + + neg_span_list.append(torch.tensor(neg_input)) + + pos_input = self.tokenizer.build_inputs_with_special_tokens(pos_ids) + + return torch.tensor(pos_input).to(self.device), torch.stack(neg_span_list).to( + self.device + ) + + def get_slice(self, doc): + try: + end = random.choice(range(4, len(doc))) + return doc[:end] + except: + return doc + + def prepare_train_data(self, data_list, num_negs): + train_list = [] + for each_item in data_list: + train_list.append(list(self.prepare_each_item(each_item, num_negs))) + return train_list + + def prepare_each_item(self, train_data_item, num_negs): + pos_doc = train_data_item["pos"] + if self.datatype == "single": + neg_docs = [train_data_item["neg"]] + elif self.datatype == "multiple": + neg_docs = train_data_item["negs"][:num_negs] + + pos_span = pos_doc + pos_span = " ".join(pos_span) + pos_tokens = self.tokenizer.tokenize(pos_span) + pos_ids = self.tokenizer.convert_tokens_to_ids(pos_tokens) + pos_ids = self.pad_ids(pos_ids) + + pos_slice = " ".join(self.get_slice(pos_doc)) + slice_tokens = self.tokenizer.tokenize(pos_slice) + slice_ids = self.tokenizer.convert_tokens_to_ids(slice_tokens) + slice_ids = self.pad_ids(slice_ids) + + neg_span_list = [] + for neg_doc in neg_docs: + neg_span = neg_doc + neg_span = " ".join(neg_span) + neg_tokens = self.tokenizer.tokenize(neg_span) + neg_ids = self.tokenizer.convert_tokens_to_ids(neg_tokens) + neg_ids = self.pad_ids(neg_ids) + + neg_input = self.tokenizer.build_inputs_with_special_tokens(neg_ids) + + neg_span_list.append(torch.tensor(neg_input)) + + pos_input = self.tokenizer.build_inputs_with_special_tokens(pos_ids) + slice_input = self.tokenizer.build_inputs_with_special_tokens(slice_ids) + + pos_tensor = torch.tensor(pos_input).unsqueeze(0).to(self.device) + slice_tensor = torch.tensor(slice_input).unsqueeze(0).to(self.device) + neg_tensor_stack = torch.stack(neg_span_list).unsqueeze(0).to(self.device) + + return pos_tensor, slice_tensor, neg_tensor_stack + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + return self.prepare_data(idx) + + +class LoadData: + def __init__(self, fname, batch_size, model_size, device, datatype, negs, max_len): + self.fname = fname + self.batch_size = batch_size + self.dataset = CoherenceMomentumDataset( + fname, model_size, device, datatype, negs, max_len + ) + + def data_loader(self): + data_sampler = SequentialSampler(self.dataset) + loader = DataLoader( + dataset=self.dataset, sampler=data_sampler, batch_size=self.batch_size + ) + return loader + + +class TrainMomentumModel: + def __init__(self, model_config_path, train_config_path): + self.model_config = CoherenceMomentumConfig.from_pretrained(model_config_path) + self.train_config = load_train_config( + CoherenceMomentumTrainConfig, train_config_path + ) + + self.model_size = self.model_config.model_size + self.num_negs = self.model_config.num_negs + self.max_len = self.model_config.max_len + self.rank_negs = self.model_config.num_rank_negs + + self.dev_file = self.train_config.dev_file + if self.train_config.test_file: + self.test_file = self.train_config.test_file + else: + self.test_file = self.train_config.dev_file + self.output_dir = ( + self.train_config.output_dir + + "-" + + datetime.datetime.now().strftime("%Y%m%d%H%M%S") + ) + self.datatype = self.train_config.data_type + self.eval_interval = self.train_config.eval_interval + self.seed = self.train_config.seed + self.batch_size = self.train_config.batch_size + self.train_steps = self.train_config.train_steps + self.num_checkpoints = self.train_config.num_checkpoints + self.best_checkpoints = [] # List of tuples of (accuracy, path) + + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + torch.manual_seed(self.seed) + torch.cuda.manual_seed_all(self.seed) + + self.xlnet_model = CoherenceMomentumModel(self.model_config) + self.xlnet_model.init_encoders() + self.xlnet_model = self.xlnet_model.to(self.device) + + self.optimizer = AdamW( + self.xlnet_model.parameters(), lr=self.train_config.lr_start + ) + self.scheduler = SWALR( + self.optimizer, + anneal_strategy="linear", + anneal_epochs=self.train_config.lr_anneal_epochs, + swa_lr=self.train_config.lr_end, + ) + + self.total_loss = 0.0 + self.bestacc = 0.0 + + def get_ranked_negs(self, neg_scores): + ranked_idx = sorted( + range(len(neg_scores)), key=neg_scores.__getitem__, reverse=True + ) + hard_negs = ranked_idx[: self.num_negs] + return hard_negs + + def get_next_train_data(self, processed_exploration_data): + self.xlnet_model.eval() + + next_train_data = [] + with torch.no_grad(): + for i, each_data in enumerate(processed_exploration_data): + try: + pos_input, slice_input, neg_input = each_data + except Exception as e: + print(e) + continue + + pos_score, neg_scores = self.xlnet_model.eval_forward( + pos_input, neg_input + ) + pos_score = pos_score.to(torch.device("cpu")) + neg_scores = neg_scores.to(torch.device("cpu")) + + next_neg_idx = self.get_ranked_negs(neg_scores) + + if len(next_neg_idx) < self.num_negs: + continue + + neg_data_list = torch.stack( + [neg_input[0][x] for x in next_neg_idx] + ).unsqueeze(0) + next_train_data.append([pos_input, slice_input, neg_data_list]) + + return next_train_data + + def hard_negs_controller(self): + start = time.time() + train_data = CoherenceMomentumDataset( + self.train_config.train_file, + self.model_size, + self.device, + self.datatype, + self.num_negs, + self.max_len, + ) + init_train_data = train_data.data[: self.train_steps] + total_iterations = len(train_data.data) // self.train_steps + + for iteration_index in range(total_iterations): + full_time = time.asctime(time.localtime(time.time())) + + print( + "ITERATION: {} TIME: {} LOSS: {}".format( + iteration_index, full_time, self.total_loss + ) + ) + self.total_loss = 0.0 + + if iteration_index == 0: + processed_train_data_list = train_data.prepare_train_data( + init_train_data, self.num_negs + ) + self.train_xlnet_model(processed_train_data_list) + else: + start_index = iteration_index * self.train_steps + end_index = start_index + self.train_steps + + processed_explore_data_list = train_data.prepare_train_data( + train_data.data[start_index:end_index], self.rank_negs + ) + next_train_data = self.get_next_train_data(processed_explore_data_list) + self.train_xlnet_model(next_train_data) + + if (self.train_steps * (iteration_index + 1)) % self.eval_interval == 0: + self.scheduler.step() + self.eval_model( + self.dev_file, self.train_steps * (iteration_index + 1), start + ) + + self.eval_model(self.dev_file, self.train_steps * (iteration_index + 1), start) + + def train_xlnet_model(self, train_loader): + self.xlnet_model.train() + + for step, data in enumerate(train_loader): + + self.optimizer.zero_grad() + + try: + pos_input, slice_input, neg_input = data + except Exception as e: + print(e) + continue + + combined_loss = self.xlnet_model(pos_input, slice_input, neg_input) + combined_loss.backward() + + self.xlnet_model.update_momentum_encoder() + self.optimizer.step() + + self.total_loss += combined_loss.item() + + def eval_model(self, data_file, step, start): + self.xlnet_model.eval() + test_data = LoadData( + data_file, + self.batch_size, + self.model_size, + self.device, + self.datatype, + self.num_negs, + self.max_len, + ) + test_loader = test_data.data_loader() + + correct = 0.0 + total = 0.0 + + with torch.no_grad(): + for data in test_loader: + try: + pos_input, neg_inputs = data + except Exception as e: + print(e) + continue + + pos_score, neg_scores = self.xlnet_model.eval_forward( + pos_input, neg_inputs + ) + try: + max_neg_score = torch.max(neg_scores, -1).values + except: + max_neg_score = max(neg_scores) + + if pos_score > max_neg_score: + correct += 1.0 + total += 1.0 + + self.xlnet_model.train() + end = time.time() + full_time = time.asctime(time.localtime(end)) + acc = correct / total + if data_file == self.dev_file: + print( + "DEV EVAL Time: {} Elapsed: {} Steps: {} Acc: {}".format( + full_time, end - start, step, acc + ) + ) + if step > 0: + self.bestacc = acc + self.save_model(self.output_dir, step, acc) + elif data_file == self.test_file: + print( + "Please evaluate the test file separately with the best saved checkpoint." + ) + print( + "TEST EVAL Time: {} Steps: {} Acc: {}".format( + full_time, end - start, step, acc + ) + ) + + return + + def save_model(self, output_dir, step, accuracy): + if not os.path.isdir(output_dir): + os.mkdir(output_dir) + model_path = os.path.join( + output_dir, + f"momentum_seed-{self.seed}_bs-{self.batch_size}_lr-{self.train_config.lr_start}" + f"_step-{step}_type-{self.model_size}_acc-{accuracy:.3f}", + ) + + if len(self.best_checkpoints) == 0: + self.xlnet_model.save_pretrained(model_path) + self.best_checkpoints.append((accuracy, model_path)) + elif accuracy > self.best_checkpoints[-1][0]: + self.xlnet_model.save_pretrained(model_path) + self.best_checkpoints.append((accuracy, model_path)) + self.best_checkpoints.sort(key=lambda x: x[0], reverse=True) + if len(self.best_checkpoints) > self.num_checkpoints: + _, dir_to_delete = self.best_checkpoints.pop() + shutil.rmtree(dir_to_delete, ignore_errors=True) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--train_config_file", type=str) + parser.add_argument("--model_config_file", type=str) + args = parser.parse_args() + return args + + +if __name__ == "__main__": + args = parse_args() + trainer = TrainMomentumModel(args.model_config_file, args.train_config_file) + trainer.hard_negs_controller() diff --git a/sgnlp/models/coherence_momentum/train_config.json b/sgnlp/models/coherence_momentum/train_config.json new file mode 100644 index 0000000..4a4f539 --- /dev/null +++ b/sgnlp/models/coherence_momentum/train_config.json @@ -0,0 +1,15 @@ +{ + "batch_size": 1, + "data_type": "multiple", + "dev_file": "permuted_wsj_dev_max-negs-100_size-4K", + "eval_file": "", + "eval_interval": 1000, + "lr_anneal_epochs": 50, + "lr_end": 1e-06, + "lr_start": 5e-06, + "output_dir": "outputs", + "seed": 100, + "test_file": "", + "train_file": "permuted_wsj_train_max-negs-100_size-46K", + "train_steps": 200 +} \ No newline at end of file diff --git a/sgnlp/models/coherence_momentum/train_config.py b/sgnlp/models/coherence_momentum/train_config.py new file mode 100644 index 0000000..c727962 --- /dev/null +++ b/sgnlp/models/coherence_momentum/train_config.py @@ -0,0 +1,29 @@ +from dataclasses import dataclass, field + + +@dataclass +class CoherenceMomentumTrainConfig: + train_file: str = field(metadata={"help": "Train file path."}) + dev_file: str = field(metadata={"help": "Dev file path."}) + test_file: str = field(metadata={"help": "Test file path."}) + eval_file: str = field(metadata={"help": "Eval file path."}) + output_dir: str = field(metadata={"help": "Output directory."}) + DATA_TYPE_CHOICES = ["multiple", "single"] + data_type: str = field( + metadata={"choices": DATA_TYPE_CHOICES, "help": "Data format."} + ) + lr_start: float = field(default=5e-06) + lr_end: float = field(default=1e-06) + lr_anneal_epochs: int = field(default=50) + eval_interval: int = field(default=1000) + seed: int = field(default=100) + batch_size: int = field(default=1) + train_steps: int = field(default=200) + num_checkpoints: int = field( + default=5, metadata={"help": "Number of best checkpoints to save"} + ) + + def __post_init__(self): + assert ( + self.data_type in self.DATA_TYPE_CHOICES + ), f"Data type must be one of {self.DATA_TYPE_CHOICES}" diff --git a/sgnlp/models/emotion_entailment/modeling.py b/sgnlp/models/emotion_entailment/modeling.py index 2eb7c78..5d2b9fc 100644 --- a/sgnlp/models/emotion_entailment/modeling.py +++ b/sgnlp/models/emotion_entailment/modeling.py @@ -1,3 +1,6 @@ +from typing import Optional + +import torch from transformers import RobertaForSequenceClassification @@ -34,5 +37,28 @@ class RecconEmotionEntailmentModel(RobertaForSequenceClassification): def __init__(self, config): super().__init__(config) - def forward(self, **kwargs): - return super().forward(**kwargs) + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + token_type_ids: Optional[torch.LongTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + return super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + labels=labels, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) diff --git a/sgnlp/models/lsr/train.py b/sgnlp/models/lsr/train.py index eb32349..3be2137 100644 --- a/sgnlp/models/lsr/train.py +++ b/sgnlp/models/lsr/train.py @@ -60,8 +60,8 @@ def compute(self, input_theta=None): Args: input_theta (`optional`, `float`): - Prediction threshold. Provide a value between 0 to 1 if you want to compute the precision and recall - for that specific threshold. Otherwise the optimal based on f1 score will be computed for you. + Prediction threshold. Provide a value between 0 and 1 if you want to compute the precision and recall + for that specific threshold. Otherwise, the optimal based on f1 score will be computed for you. """ # Sorts in descending order by predicted value self.test_result.sort(key=lambda x: x[1], reverse=True) diff --git a/sgnlp/models/span_extraction/modeling.py b/sgnlp/models/span_extraction/modeling.py index 473585d..07f1033 100644 --- a/sgnlp/models/span_extraction/modeling.py +++ b/sgnlp/models/span_extraction/modeling.py @@ -1,3 +1,6 @@ +from typing import Optional + +import torch from transformers import BertForQuestionAnswering @@ -43,5 +46,30 @@ class RecconSpanExtractionModel(BertForQuestionAnswering): def __init__(self, config): super().__init__(config) - def forward(self, **kwargs): - return super().forward(**kwargs) + def forward( + self, + input_ids: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + position_ids: Optional[torch.Tensor] = None, + head_mask: Optional[torch.Tensor] = None, + inputs_embeds: Optional[torch.Tensor] = None, + start_positions: Optional[torch.Tensor] = None, + end_positions: Optional[torch.Tensor] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ): + return super().forward( + input_ids=input_ids, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + start_positions=start_positions, + end_positions=end_positions, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) diff --git a/sgnlp/utils/train_config.py b/sgnlp/utils/train_config.py new file mode 100644 index 0000000..e53f311 --- /dev/null +++ b/sgnlp/utils/train_config.py @@ -0,0 +1,7 @@ +import json + + +def load_train_config(config_class, json_file_path): + with open(json_file_path, "r") as f: + json_file = json.load(f) + return config_class(**json_file)