-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
141 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
from flask import request | ||
|
||
from demo_api.common import create_api | ||
from sgnlp.models.coherence_momentum import ( | ||
CoherenceMomentumModel, | ||
CoherenceMomentumConfig, | ||
CoherenceMomentumPreprocessor | ||
) | ||
|
||
app = create_api(app_name=__name__, model_card_path="model_card/coherence_momentum.json") | ||
|
||
# Load processors and models | ||
config = CoherenceMomentumConfig.from_pretrained( | ||
"https://storage.googleapis.com/sgnlp/models/coherence_momentum/config.json" | ||
) | ||
model = CoherenceMomentumModel.from_pretrained( | ||
"https://storage.googleapis.com/sgnlp/models/coherence_momentum/pytorch_model.bin", | ||
config=config | ||
) | ||
|
||
preprocessor = CoherenceMomentumPreprocessor(config.model_size, config.max_len) | ||
|
||
app.logger.info("Model initialization complete") | ||
|
||
|
||
@app.route("/predict", methods=["POST"]) | ||
def predict(): | ||
req_body = request.get_json() | ||
|
||
text1 = req_body["text1"] | ||
text2 = req_body["text2"] | ||
|
||
text1_tensor = preprocessor([text1]) | ||
text2_tensor = preprocessor([text2]) | ||
|
||
text1_score = model.get_main_score(text1_tensor["tokenized_texts"]).item() | ||
text2_score = model.get_main_score(text2_tensor["tokenized_texts"]).item() | ||
|
||
return { | ||
"text1_score": text1_score, | ||
"text2_score": text2_score | ||
} | ||
|
||
|
||
if __name__ == "__main__": | ||
app.run() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
FROM python:3.8-buster | ||
|
||
COPY ./demo_api /demo_api | ||
COPY ./sgnlp /sgnlp | ||
COPY ./setup.py /setup.py | ||
COPY ./README.md /README.md | ||
|
||
RUN pip install -r /demo_api/coherence_momentum/requirements_dev.txt | ||
|
||
WORKDIR /demo_api/coherence_momentum | ||
|
||
RUN python -m download_pretrained | ||
|
||
CMD PYTHONPATH=../../ gunicorn -c ../gunicorn.conf.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from sgnlp.models.coherence_momentum import CoherenceMomentumModel, CoherenceMomentumConfig | ||
|
||
config = CoherenceMomentumConfig.from_pretrained( | ||
"https://storage.googleapis.com/sgnlp/models/coherence_momentum/config.json" | ||
) | ||
model = CoherenceMomentumModel.from_pretrained( | ||
"https://storage.googleapis.com/sgnlp/models/coherence_momentum/pytorch_model.bin", | ||
config=config | ||
) |
36 changes: 36 additions & 0 deletions
36
demo_api/coherence_momentum/model_card/coherence_momentum.json
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
{ | ||
"name": "CoherenceMomentum", | ||
"languages": "English", | ||
"description": "This is a neural network model that makes use of a momentum encoder and hard negative mining during training. This model is able to take in a piece of text and output a coherence score. The coherence score is only meant for comparison, i.e. it is only meaningful when used to compare between two texts, and the text with the higher coherence score is deemed to be more coherent by the model.", | ||
"paper": { | ||
"text": "Jwalapuram, P., Joty, S., & Lin, X. (2022). Rethinking Self-Supervision Objectives for Generalizable Coherence Modeling. Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), May 2022 (pp. 6044-6059).", | ||
"url": "https://aclanthology.org/2022.acl-long.418/" | ||
}, | ||
"trainingDataset": "Permuted dataset derived from Linguistic Data Consortium's (LDC) Wall Street Journal (WSJ) dataset. Please contact the authors to get the dataset if you have a valid LDC license.", | ||
"evaluationDataset": "Permuted dataset derived from Linguistic Data Consortium's (LDC) Wall Street Journal (WSJ) dataset. Please contact the authors to get the dataset if you have a valid LDC license.", | ||
"evaluationScores": "0.988 accuracy on permuted WSJ dataset. 0.986 accuracy reported by authors on permuted WSJ dataset.", | ||
"trainingConfig": { | ||
"text": "https://storage.googleapis.com/sgnlp/models/coherence_momentum/config.json", | ||
"url": "https://storage.googleapis.com/sgnlp/models/coherence_momentum/config.json" | ||
}, | ||
"trainingTime": "~24 hours for ~46000 steps (batch size of 1) on a single A100 GPU", | ||
"modelWeights": { | ||
"text": "https://storage.googleapis.com/sgnlp/models/coherence_momentum/pytorch_model.bin", | ||
"url": "https://storage.googleapis.com/sgnlp/models/coherence_momentum/pytorch_model.bin" | ||
}, | ||
"modelInput": "A paragraph of text. During training, each positive example can be paired with one or more negative examples.", | ||
"modelOutput": "Coherence score for the input text.", | ||
"modelSize": "~930MB", | ||
"inferenceInfo": "Not available.", | ||
"usageScenarios": "Essay scoring, summarization, language generation.", | ||
"originalCode": { | ||
"text": "https://github.com/ntunlp/coherence-paradigm", | ||
"url": "https://github.com/ntunlp/coherence-paradigm" | ||
}, | ||
"license": { | ||
"text": "MIT License", | ||
"url": "https://choosealicense.com/licenses/mit" | ||
}, | ||
"contact": "[email protected]", | ||
"additionalInfo": "Not applicable." | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
-e. | ||
flask | ||
gunicorn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
from sgnlp.models.coherence_momentum import CoherenceMomentumModel, CoherenceMomentumConfig, \ | ||
CoherenceMomentumPreprocessor | ||
|
||
config = CoherenceMomentumConfig.from_pretrained( | ||
"https://storage.googleapis.com/sgnlp/models/coherence_momentum/config.json" | ||
) | ||
model = CoherenceMomentumModel.from_pretrained( | ||
"https://storage.googleapis.com/sgnlp/models/coherence_momentum/pytorch_model.bin", | ||
config=config | ||
) | ||
|
||
preprocessor = CoherenceMomentumPreprocessor(config.model_size, config.max_len) | ||
|
||
text1 = "Companies listed below reported quarterly profit substantially different from the average of analysts ' " \ | ||
"estimates . The companies are followed by at least three analysts , and had a minimum five-cent change in " \ | ||
"actual earnings per share . Estimated and actual results involving losses are omitted . The percent " \ | ||
"difference compares actual profit with the 30-day estimate where at least three analysts have issues " \ | ||
"forecasts in the past 30 days . Otherwise , actual profit is compared with the 300-day estimate . " \ | ||
"Source : Zacks Investment Research" | ||
text2 = "The companies are followed by at least three analysts , and had a minimum five-cent change in actual " \ | ||
"earnings per share . The percent difference compares actual profit with the 30-day estimate where at least " \ | ||
"three analysts have issues forecasts in the past 30 days . Otherwise , actual profit is compared with the " \ | ||
"300-day estimate . Source : Zacks Investment Research. Companies listed below reported quarterly profit " \ | ||
"substantially different from the average of analysts ' estimates . Estimated and actual results involving " \ | ||
"losses are omitted ." | ||
|
||
text1_tensor = preprocessor([text1]) | ||
text2_tensor = preprocessor([text2]) | ||
|
||
text1_score = model.get_main_score(text1_tensor["tokenized_texts"]).item() | ||
text2_score = model.get_main_score(text2_tensor["tokenized_texts"]).item() | ||
|
||
print(text1_score, text2_score) |