This project aims to classify tweets related to disasters using a machine learning model based on the DistilBERT architecture. The primary objective is to determine whether a given tweet is about a real disaster (`target=1`) or not (`target=0`).
## Table of Contents

- Project Overview
- Dataset
- Installation
- Training the Model
- Evaluation
- Prediction
- Gradio UI
- Results
- Contributing
- License
## Project Overview

This project uses the DistilBERT model from Hugging Face's Transformers library for sequence classification. The pipeline covers data preprocessing, tokenization, model training, evaluation, and prediction. Additionally, a Gradio UI is provided for testing the model interactively.
## Dataset

The dataset consists of tweets labeled as either related to disasters or not. It is split into training, validation, and test sets: the training set is used to train the model, the validation set for hyperparameter tuning, and the test set for evaluating the model's performance.
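A minimal sketch of one way to produce such a split with the `datasets` library is shown below. The file name `train.csv`, the `text`/`target` column names (matching the common Kaggle disaster-tweets format), and the 80/10/10 ratios are assumptions, not details confirmed by this repository.

```python
from datasets import load_dataset

# Load the labeled tweets; "train.csv" and its column names are assumed.
ds = load_dataset("csv", data_files="train.csv")["train"]
ds = ds.rename_column("target", "label")

# Carve out validation and test sets (80/10/10 here; the actual ratios
# used by the project are not documented).
split = ds.train_test_split(test_size=0.2, seed=42)
holdout = split["test"].train_test_split(test_size=0.5, seed=42)
train_ds, val_ds = split["train"], holdout["train"]
test_ds = holdout["test"]
```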
## Installation

To install the required dependencies, run the following commands:

```bash
git clone https://github.com/utsav-desai/CS772.git
pip install -U accelerate datasets
pip install "transformers[torch]"
pip install spacy_cleaner
pip install gradio
```
## Training the Model

To train the DistilBERT model, follow these steps:
- Preprocess the data using the `spacy_cleaner` library.
- Tokenize the text data using the `AutoTokenizer` from Hugging Face.
- Train the DistilBERT model using the `Trainer` and `TrainingArguments` classes, as sketched below.
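The following sketch strings these steps together under a few assumptions: it continues from the `train_ds`/`val_ds` split above, uses `distilbert-base-uncased` as the base checkpoint, omits the `spacy_cleaner` preprocessing step for brevity, and picks illustrative hyperparameters (epochs, batch size, sequence length) that the repository does not document.

```python
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

checkpoint = "distilbert-base-uncased"  # assumed base model
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)

def tokenize(batch):
    # Pad/truncate tweets to a fixed length; 128 tokens is an assumption.
    return tokenizer(batch["text"], truncation=True,
                     padding="max_length", max_length=128)

train_ds = train_ds.map(tokenize, batched=True)
val_ds = val_ds.map(tokenize, batched=True)

args = TrainingArguments(
    output_dir="my_model_weights",   # path reused by the Gradio demo below
    num_train_epochs=3,              # illustrative hyperparameters
    per_device_train_batch_size=16,
    evaluation_strategy="epoch",
)

trainer = Trainer(model=model, args=args,
                  train_dataset=train_ds, eval_dataset=val_ds)
trainer.train()
trainer.save_model("my_model_weights")
```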
## Evaluation

The model's performance is evaluated on the validation set using accuracy and F1-score metrics.
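For reference, a metric function in the shape the `Trainer` expects might look like the following. The use of scikit-learn here is an assumption; the project could equally compute these metrics with the `evaluate` library.

```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score

def compute_metrics(eval_pred):
    # eval_pred is a (logits, labels) pair supplied by the Trainer.
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    return {
        "accuracy": accuracy_score(labels, preds),
        "f1": f1_score(labels, preds),
    }

# Hook it in: Trainer(..., compute_metrics=compute_metrics)
```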
## Prediction

To make predictions on new data, the trained model is used to classify tweets. The model outputs a label indicating whether the tweet is related to a disaster or not.
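One convenient way to run batch predictions with the saved weights is the `transformers` `pipeline` API, sketched below. `my_model_weights` matches the path used in the Gradio demo, while the tokenizer checkpoint is an assumption (if the tokenizer was saved alongside the model, the path alone suffices).

```python
from transformers import pipeline

clf = pipeline(
    "text-classification",
    model="my_model_weights",
    tokenizer="distilbert-base-uncased",  # assumed base tokenizer
)

# Each result is a dict with "label" (LABEL_0 / LABEL_1) and "score" keys.
results = clf(["Wildfire evacuation ordered downtown", "What a lovely sunny day"])
```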
## Gradio UI

A Gradio UI is provided for testing the model interactively. The UI allows users to input a text sentence and get the model's prediction.
To launch the Gradio UI, run the following code:
```python
import re

import gradio as gr
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Load the tokenizer and the fine-tuned model once at startup.
# "distilbert-base-uncased" is the assumed base checkpoint; if the tokenizer
# was saved with the model, load it from "my_model_weights" instead.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("my_model_weights")
model.eval()

def classify_text(text):
    # Preprocess the text: lowercase, strip special characters and mentions
    text = text.lower()
    text = re.sub(r"[#=></.]", "", text)
    text = re.sub(r"@\w+", "", text)
    # Tokenize the text
    tokenized_text = tokenizer(text, truncation=True)
    # Convert the tokenized text to tensors with a batch dimension
    input_ids = torch.tensor(tokenized_text["input_ids"]).unsqueeze(0)
    attention_mask = torch.tensor(tokenized_text["attention_mask"]).unsqueeze(0)
    # Make predictions
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)
    prediction = torch.argmax(outputs.logits, dim=-1).item()
    # Return a human-readable label
    return "Disaster" if prediction == 1 else "Not a disaster"

demo = gr.Interface(
    fn=classify_text,
    inputs=gr.Textbox(label="Enter a text sentence here"),
    outputs="label",
    examples=[
        "This is a disaster!",
        "Earthquake is expected in China!",
        "I'm feeling happy.",
    ],
)

demo.launch()
```
## Results

The model achieved an accuracy of 98.68% and an F1-score of 98.48% on the validation set. These metrics indicate that the model is effective at classifying disaster-related tweets.