Initial commit for Text-sentiment-classifier #8

Open · wants to merge 8 commits into `master`
121 changes: 121 additions & 0 deletions bert-text-classifier/README.md
@@ -0,0 +1,121 @@
# MAX for TensorFlow.js: Text Sentiment Classifier

This is a TensorFlow.js port of the [MAX Text Sentiment Classifier](https://developer.ibm.com/exchanges/models/all/max-text-sentiment-classifier/). The model detects whether a text fragment leans towards a positive or a negative sentiment.

## Install

### Browser

```html
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
<script src="https://cdn.jsdelivr.net/npm/@codait/text-sentiment-classifier"></script>
```

### Node.js

```bash
npm install --save @codait/text-sentiment-classifier
```

## Usage

The complete examples for browser and Node.js environments are in the [`/examples`](https://github.com/CODAIT/max-tfjs-models/tree/master/text-sentiment-classifier/examples) directory.

### Browser

> **Note**: _When loaded in a browser, the global variable `textSentimentClassifier` will be available to access the API._

```javascript
textSentimentClassifier
  .predict("i like strawberries")
  .then(prediction => {
    console.log(prediction);
  });
```

### Node.js

```javascript
const tc = require('@codait/text-sentiment-classifier');
tc.predict("i like strawberries").then(res => console.log(res)); // { pos: 0.9981953501701355, neg: 0.0018045296892523766 }
```

### API

- **loadModel()**

Loads the model files.

  When running in Node.js for the first time, the model assets are downloaded locally to the `/model` directory. Subsequent calls load the model from that directory.

Returns the TensorFlow.js model.

- **processInput(text)**

Processes the input text to the shape and format expected by the model.

  `text` - the sentence to be processed. Ending the sentence with a period is recommended but not required.

  Returns a named tensor map that contains:
  `{'segment_ids_1': Tensor of shape [128], 'input_ids_1': Tensor of shape [128], 'input_mask_1': Tensor of shape [128]}`

- **runInference(inputFeatures)**

  Runs inference on the named tensor map passed in. The output is a tensor containing the softmax probabilities for the positive and negative classes.

  `inputFeatures` - a named tensor map representation of a text.

Returns the inference results as a 1D tensor.

- **processOutput(tensor)**

  Transforms the inference output into a JSON object.

`tensor` - the model output from running inference.

Returns an object containing: `{neg: number, pos: number}`


- **predict(text)**

  Loads the model, processes the input text, runs inference, processes the inference output, and returns a prediction object. This convenience function avoids having to call `loadModel`, `processInput`, `runInference`, and `processOutput` individually; a step-by-step sketch that does call them individually is shown after this list.

  `text` - the sentence to be analyzed. Ending the sentence with a period is recommended but not required.

Returns an object containing: `{neg: number, pos: number}`

- **encode(text)**

  Tokenizes the text into token ids using the 32k BERT vocabulary.

`text` - sentence to be encoded.

Returns an array of BERT token ids.

- **idsToTokens(ids)**

  Transforms BERT token ids back into tokens.

`ids` - BERT token ids.

Returns an array of BERT tokens.

- **version**

  Returns the version of the package.
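
For reference, the Node.js sketch below strings the individual functions together, roughly mirroring what `predict` does internally, followed by the two tokenization helpers. It assumes the function names and return values described above and that each call may return a promise; treat it as an illustrative outline rather than canonical usage.

```javascript
const tc = require('@codait/text-sentiment-classifier');

async function classify(text) {
  // Download (first run) or load the converted model assets.
  await tc.loadModel();

  // Build the named tensor map: segment_ids_1, input_ids_1, input_mask_1.
  const inputFeatures = await tc.processInput(text);

  // Run the model; the result is a 1D tensor of softmax values.
  const outputTensor = await tc.runInference(inputFeatures);

  // Convert the tensor into { neg: number, pos: number }.
  return tc.processOutput(outputTensor);
}

classify('i like strawberries.')
  .then(result => console.log(result)); // e.g. { pos: 0.99..., neg: 0.00... }

// Tokenization helpers: BERT token ids and back to tokens.
tc.encode('i like strawberries.')
  .then(ids => tc.idsToTokens(ids))
  .then(tokens => console.log(tokens));
```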

## Model

The model assets produced by converting the pre-trained model to the TensorFlow.js format can be found in the `/model` directory after `loadModel` is called in Node.js.
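
To verify that the converted assets are present, a minimal Node.js check along these lines should work once `loadModel()` has resolved. The `./model` path is an assumption (the assets may land relative to the installed package instead); adjust it for your setup.

```javascript
const fs = require('fs');
const tc = require('@codait/text-sentiment-classifier');

tc.loadModel().then(() => {
  // Path is an assumption; see the note above.
  console.log(fs.readdirSync('./model')); // e.g. [ 'model.json', 'group1-shard1of105.bin', ... ]
});
```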

## Resources

- [MAX Text Sentiment Classifier](https://developer.ibm.com/exchanges/models/all/max-text-sentiment-classifier/)

## License

[Apache-2.0](https://github.com/CODAIT/max-tfjs-models/blob/master/LICENSE)
69 changes: 69 additions & 0 deletions bert-text-classifier/examples/test.html
@@ -0,0 +1,69 @@
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8"/>
<title>text classifier</title>
</head>
<body onload="loadModel()">
<h1>Text Sentiment Classifier</h1>
<div id="div1">
<input type="text" id="text_input" name="text_input" value="Enter a sentence.">
<input type="button" id="submit_text" value="Analyze">
</div>
<p>
<label for="status">Status:</label>
<label name="status" id="status">Loading Model...</label>
</p>
<p>
<label for="status">Result:</label>
<label name="result" id="result"></label>
</p>
<script src="../dist/src/max.sentimentclass.js"></script>
<script>
let submitButton = document.getElementById("submit_text");
let statusElement = document.getElementById("status");
let resultElement = document.getElementById("result");
submitButton.addEventListener("click", runPredict);
function enableAnalyzeButton(){
submitButton.removeAttribute('disabled');
}

function disableAnalyzeButton(){
submitButton.setAttribute('disabled','true');
}

function loadModel(){
disableAnalyzeButton();
textSentimentClassifier.loadModel().then( () => {
enableAnalyzeButton();
updateStatus("Model Loaded");
})
}
function updateStatus(msg) {
statusElement.innerHTML = msg;
}
function appendResult(msg) {
resultElement.innerHTML += msg;
}
function runPredict(){
updateStatus("Running prediction...");
disableAnalyzeButton();
const text = document.getElementById("text_input").value;
setTimeout(() => { //yield the thread to update UI
textSentimentClassifier.predict(text)
.then((res) =>{
appendResult(`
<br> ${text}
<br>Positive &#128512;: ${res['pos'].toFixed(4)}
<br>Negative &#128534;: ${res['neg'].toFixed(4)}<br>`);
updateStatus("Ok");
enableAnalyzeButton();
})
.catch((err) => {
updateStatus(err);
enableAnalyzeButton();
})
}, 20);
}
</script>
</body>
</html>
3 changes: 3 additions & 0 deletions bert-text-classifier/examples/test.js
@@ -0,0 +1,3 @@
const tc = require("../dist/src/max.sentimentclass.cjs.js");
tc.predict("i like strawberries").then(res => console.log(res)); // { pos: ..., neg: ... }
tc.encode("i like strawberries").then(res => console.log(res)); // array of BERT token ids
124 changes: 124 additions & 0 deletions bert-text-classifier/model/conversion_manifest.json
@@ -0,0 +1,124 @@
{
"name": "Max-Text-Sentiment-Classifier",
"url": "https://developer.ibm.com/exchanges/models/all/max-text-sentiment-classifier/",
"source": "max-text-sentiment-classifier/1.0/assets.tar.gz/sentiment_BERT_base_uncased/",
"framework": "tf_js",
"converter": {
"tensorflowjs": {
"version": "1.2.9",
"params": {
"output_node_names": "loss/Softmax",
"input_format": "tf_saved_model",
"output_json": true
}
}
},
"output": [
"model.json",
"group1-shard1of105.bin",
"group1-shard2of105.bin",
"group1-shard3of105.bin",
"group1-shard4of105.bin",
"group1-shard5of105.bin",
"group1-shard6of105.bin",
"group1-shard7of105.bin",
"group1-shard8of105.bin",
"group1-shard9of105.bin",
"group1-shard10of105.bin",
"group1-shard11of105.bin",
"group1-shard12of105.bin",
"group1-shard13of105.bin",
"group1-shard14of105.bin",
"group1-shard15of105.bin",
"group1-shard16of105.bin",
"group1-shard17of105.bin",
"group1-shard18of105.bin",
"group1-shard19of105.bin",
"group1-shard20of105.bin",
"group1-shard21of105.bin",
"group1-shard22of105.bin",
"group1-shard23of105.bin",
"group1-shard24of105.bin",
"group1-shard25of105.bin",
"group1-shard26of105.bin",
"group1-shard27of105.bin",
"group1-shard28of105.bin",
"group1-shard29of105.bin",
"group1-shard30of105.bin",
"group1-shard31of105.bin",
"group1-shard32of105.bin",
"group1-shard33of105.bin",
"group1-shard34of105.bin",
"group1-shard35of105.bin",
"group1-shard36of105.bin",
"group1-shard37of105.bin",
"group1-shard38of105.bin",
"group1-shard39of105.bin",
"group1-shard40of105.bin",
"group1-shard41of105.bin",
"group1-shard42of105.bin",
"group1-shard43of105.bin",
"group1-shard44of105.bin",
"group1-shard45of105.bin",
"group1-shard46of105.bin",
"group1-shard47of105.bin",
"group1-shard48of105.bin",
"group1-shard49of105.bin",
"group1-shard50of105.bin",
"group1-shard51of105.bin",
"group1-shard52of105.bin",
"group1-shard53of105.bin",
"group1-shard54of105.bin",
"group1-shard55of105.bin",
"group1-shard56of105.bin",
"group1-shard57of105.bin",
"group1-shard58of105.bin",
"group1-shard59of105.bin",
"group1-shard60of105.bin",
"group1-shard61of105.bin",
"group1-shard62of105.bin",
"group1-shard63of105.bin",
"group1-shard64of105.bin",
"group1-shard65of105.bin",
"group1-shard66of105.bin",
"group1-shard67of105.bin",
"group1-shard68of105.bin",
"group1-shard69of105.bin",
"group1-shard70of105.bin",
"group1-shard71of105.bin",
"group1-shard72of105.bin",
"group1-shard73of105.bin",
"group1-shard74of105.bin",
"group1-shard75of105.bin",
"group1-shard76of105.bin",
"group1-shard77of105.bin",
"group1-shard78of105.bin",
"group1-shard79of105.bin",
"group1-shard80of105.bin",
"group1-shard81of105.bin",
"group1-shard82of105.bin",
"group1-shard83of105.bin",
"group1-shard84of105.bin",
"group1-shard85of105.bin",
"group1-shard86of105.bin",
"group1-shard87of105.bin",
"group1-shard88of105.bin",
"group1-shard89of105.bin",
"group1-shard90of105.bin",
"group1-shard91of105.bin",
"group1-shard92of105.bin",
"group1-shard93of105.bin",
"group1-shard94of105.bin",
"group1-shard95of105.bin",
"group1-shard96of105.bin",
"group1-shard97of105.bin",
"group1-shard98of105.bin",
"group1-shard99of105.bin",
"group1-shard100of105.bin",
"group1-shard101of105.bin",
"group1-shard102of105.bin",
"group1-shard103of105.bin",
"group1-shard104of105.bin",
"group1-shard105of105.bin"
]
}