Skip to content

Commit

Permalink
[#57] Add files for demo api
Browse files Browse the repository at this point in the history
  • Loading branch information
jonheng committed Sep 19, 2022
1 parent a54c7ed commit 0408670
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 0 deletions.
46 changes: 46 additions & 0 deletions demo_api/coherence_momentum/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from flask import request

from demo_api.common import create_api
from sgnlp.models.coherence_momentum import (
CoherenceMomentumModel,
CoherenceMomentumConfig,
CoherenceMomentumPreprocessor
)

app = create_api(app_name=__name__, model_card_path="model_card/coherence_momentum.json")

# Load processors and models
config = CoherenceMomentumConfig.from_pretrained(
"https://storage.googleapis.com/sgnlp/models/coherence_momentum/config.json"
)
model = CoherenceMomentumModel.from_pretrained(
"https://storage.googleapis.com/sgnlp/models/coherence_momentum/pytorch_model.bin",
config=config
)

preprocessor = CoherenceMomentumPreprocessor(config.model_size, config.max_len)

app.logger.info("Model initialization complete")


@app.route("/predict", methods=["POST"])
def predict():
req_body = request.get_json()

text1 = req_body["text1"]
text2 = req_body["text2"]

text1_tensor = preprocessor([text1])
text2_tensor = preprocessor([text2])

text1_score = model.get_main_score(text1_tensor["tokenized_texts"]).item()
text2_score = model.get_main_score(text2_tensor["tokenized_texts"]).item()

return {
"text1_score": text1_score,
"text2_score": text2_score
}


if __name__ == "__main__":
app.run()
14 changes: 14 additions & 0 deletions demo_api/coherence_momentum/dev.Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
FROM python:3.8-buster

COPY ./demo_api /demo_api
COPY ./sgnlp /sgnlp
COPY ./setup.py /setup.py
COPY ./README.md /README.md

RUN pip install -r /demo_api/coherence_momentum/requirements_dev.txt

WORKDIR /demo_api/coherence_momentum

RUN python -m download_pretrained

CMD PYTHONPATH=../../ gunicorn -c ../gunicorn.conf.py
9 changes: 9 additions & 0 deletions demo_api/coherence_momentum/download_pretrained.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from sgnlp.models.coherence_momentum import CoherenceMomentumModel, CoherenceMomentumConfig

config = CoherenceMomentumConfig.from_pretrained(
"https://storage.googleapis.com/sgnlp/models/coherence_momentum/config.json"
)
model = CoherenceMomentumModel.from_pretrained(
"https://storage.googleapis.com/sgnlp/models/coherence_momentum/pytorch_model.bin",
config=config
)
36 changes: 36 additions & 0 deletions demo_api/coherence_momentum/model_card/coherence_momentum.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
{
"name": "CoherenceMomentum",
"languages": "English",
"description": "This is a neural network model that makes use of a momentum encoder and hard negative mining during training. This model is able to take in a piece of text and output a coherence score. The coherence score is only meant for comparison, i.e. it is only meaningful when used to compare between two texts, and the text with the higher coherence score is deemed to be more coherent by the model.",
"paper": {
"text": "Jwalapuram, P., Joty, S., & Lin, X. (2022). Rethinking Self-Supervision Objectives for Generalizable Coherence Modeling. Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), May 2022 (pp. 6044-6059).",
"url": "https://aclanthology.org/2022.acl-long.418/"
},
"trainingDataset": "Permuted dataset derived from Linguistic Data Consortium's (LDC) Wall Street Journal (WSJ) dataset. Please contact the authors to get the dataset if you have a valid LDC license.",
"evaluationDataset": "Permuted dataset derived from Linguistic Data Consortium's (LDC) Wall Street Journal (WSJ) dataset. Please contact the authors to get the dataset if you have a valid LDC license.",
"evaluationScores": "0.988 accuracy on permuted WSJ dataset. 0.986 accuracy reported by authors on permuted WSJ dataset.",
"trainingConfig": {
"text": "https://storage.googleapis.com/sgnlp/models/coherence_momentum/config.json",
"url": "https://storage.googleapis.com/sgnlp/models/coherence_momentum/config.json"
},
"trainingTime": "~24 hours for ~46000 steps (batch size of 1) on a single A100 GPU",
"modelWeights": {
"text": "https://storage.googleapis.com/sgnlp/models/coherence_momentum/pytorch_model.bin",
"url": "https://storage.googleapis.com/sgnlp/models/coherence_momentum/pytorch_model.bin"
},
"modelInput": "A paragraph of text. During training, each positive example can be paired with one or more negative examples.",
"modelOutput": "Coherence score for the input text.",
"modelSize": "~930MB",
"inferenceInfo": "Not available.",
"usageScenarios": "Essay scoring, summarization, language generation.",
"originalCode": {
"text": "https://github.com/ntunlp/coherence-paradigm",
"url": "https://github.com/ntunlp/coherence-paradigm"
},
"license": {
"text": "MIT License",
"url": "https://choosealicense.com/licenses/mit"
},
"contact": "[email protected]",
"additionalInfo": "Not applicable."
}
3 changes: 3 additions & 0 deletions demo_api/coherence_momentum/requirements_dev.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
-e.
flask
gunicorn
33 changes: 33 additions & 0 deletions demo_api/coherence_momentum/usage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from sgnlp.models.coherence_momentum import CoherenceMomentumModel, CoherenceMomentumConfig, \
CoherenceMomentumPreprocessor

config = CoherenceMomentumConfig.from_pretrained(
"https://storage.googleapis.com/sgnlp/models/coherence_momentum/config.json"
)
model = CoherenceMomentumModel.from_pretrained(
"https://storage.googleapis.com/sgnlp/models/coherence_momentum/pytorch_model.bin",
config=config
)

preprocessor = CoherenceMomentumPreprocessor(config.model_size, config.max_len)

text1 = "Companies listed below reported quarterly profit substantially different from the average of analysts ' " \
"estimates . The companies are followed by at least three analysts , and had a minimum five-cent change in " \
"actual earnings per share . Estimated and actual results involving losses are omitted . The percent " \
"difference compares actual profit with the 30-day estimate where at least three analysts have issues " \
"forecasts in the past 30 days . Otherwise , actual profit is compared with the 300-day estimate . " \
"Source : Zacks Investment Research"
text2 = "The companies are followed by at least three analysts , and had a minimum five-cent change in actual " \
"earnings per share . The percent difference compares actual profit with the 30-day estimate where at least " \
"three analysts have issues forecasts in the past 30 days . Otherwise , actual profit is compared with the " \
"300-day estimate . Source : Zacks Investment Research. Companies listed below reported quarterly profit " \
"substantially different from the average of analysts ' estimates . Estimated and actual results involving " \
"losses are omitted ."

text1_tensor = preprocessor([text1])
text2_tensor = preprocessor([text2])

text1_score = model.get_main_score(text1_tensor["tokenized_texts"]).item()
text2_score = model.get_main_score(text2_tensor["tokenized_texts"]).item()

print(text1_score, text2_score)

0 comments on commit 0408670

Please sign in to comment.