diff --git a/.gitignore b/.gitignore index f15b601..1286603 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,8 @@ # Datasets +pytorch_ipynb/rnn/yelp_review_polarity_csv/ +pytorch_ipynb/rnn/ag_news_csv/ +pytorch_ipynb/rnn/amazon_review_polarity_csv/ HistoricalColor-ECCV2012* AFAD-Lite tarball* diff --git a/README.md b/README.md index ac4a008..f90bbdc 100644 --- a/README.md +++ b/README.md @@ -213,6 +213,13 @@ A collection of various deep learning architectures, models, and tips for Tensor    [PyTorch: [GitHub](pytorch_ipynb/rnn/rnn_gru_packed_imdb.ipynb) | [Nbviewer](https://nbviewer.jupyter.org/github/rasbt/deeplearning-models/blob/master/pytorch_ipynb/rnn/rnn_gru_packed_imdb.ipynb)] - Multilayer bi-directional RNN (IMDB)    [PyTorch: [GitHub](pytorch_ipynb/rnn/rnn_gru_packed_imdb.ipynb) | [Nbviewer](https://nbviewer.jupyter.org/github/rasbt/deeplearning-models/blob/master/pytorch_ipynb/rnn/rnn_gru_packed_imdb.ipynb)] +- Bidirectional Multi-layer RNN with LSTM with Own Dataset in CSV Format (AG News) +   [PyTorch: [GitHub](pytorch_ipynb/rnn/rnn_bi_multilayer_lstm_own_csv_agnews.ipynb) | [Nbviewer](https://nbviewer.jupyter.org/github/rasbt/deeplearning-models/blob/master/pytorch_ipynb/rnn/rnn_bi_multilayer_lstm_own_csv_agnews.ipynb)] +- Bidirectional Multi-layer RNN with LSTM with Own Dataset in CSV Format (Yelp Review Polarity) +   [PyTorch: [GitHub](pytorch_ipynb/rnn/rnn_bi_multilayer_lstm_own_csv_yelp-polarity.ipynb) | [Nbviewer](https://nbviewer.jupyter.org/github/rasbt/deeplearning-models/blob/master/pytorch_ipynb/rnn/rnn_bi_multilayer_lstm_own_csv_yelp-polarity.ipynb)] +- Bidirectional Multi-layer RNN with LSTM with Own Dataset in CSV Format (Amazon Review Polarity) +   [PyTorch: [GitHub](pytorch_ipynb/rnn/rnn_bi_multilayer_lstm_own_csv_amazon-polarity.ipynb) | [Nbviewer](https://nbviewer.jupyter.org/github/rasbt/deeplearning-models/blob/master/pytorch_ipynb/rnn/rnn_bi_multilayer_lstm_own_csv_amazon-polarity.ipynb)] + #### Many-to-Many / Sequence-to-Sequence diff --git a/pytorch_ipynb/rnn/rnn_bi_multilayer_lstm_own_csv_agnews.ipynb b/pytorch_ipynb/rnn/rnn_bi_multilayer_lstm_own_csv_agnews.ipynb new file mode 100644 index 0000000..8f7d0b3 --- /dev/null +++ b/pytorch_ipynb/rnn/rnn_bi_multilayer_lstm_own_csv_agnews.ipynb @@ -0,0 +1,2152 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n", + "- Author: Sebastian Raschka\n", + "- GitHub Repository: https://github.com/rasbt/deeplearning-models" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "vY4SK0xKAJgm" + }, + "source": [ + "# Bidirectional Multi-layer RNN with LSTM with Own Dataset in CSV Format (AG News)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Dataset Description\n", + "\n", + "```\n", + "AG's News Topic Classification Dataset\n", + "\n", + "Version 3, Updated 09/09/2015\n", + "\n", + "\n", + "ORIGIN\n", + "\n", + "AG is a collection of more than 1 million news articles. News articles have been gathered from more than 2000 news sources by ComeToMyHead in more than 1 year of activity. ComeToMyHead is an academic news search engine which has been running since July, 2004. 
The dataset is provided by the academic community for research purposes in data mining (clustering, classification, etc), information retrieval (ranking, search, etc), xml, data compression, data streaming, and any other non-commercial activity. For more information, please refer to the link http://www.di.unipi.it/~gulli/AG_corpus_of_news_articles.html .\n", + "\n", + "The AG's news topic classification dataset is constructed by Xiang Zhang (xiang.zhang@nyu.edu) from the dataset above. It is used as a text classification benchmark in the following paper: Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances in Neural Information Processing Systems 28 (NIPS 2015).\n", + "\n", + "\n", + "DESCRIPTION\n", + "\n", + "The AG's news topic classification dataset is constructed by choosing 4 largest classes from the original corpus. Each class contains 30,000 training samples and 1,900 testing samples. The total number of training samples is 120,000 and testing 7,600.\n", + "\n", + "The file classes.txt contains a list of classes corresponding to each label.\n", + "\n", + "The files train.csv and test.csv contain all the training samples as comma-sparated values. There are 3 columns in them, corresponding to class index (1 to 4), title and description. The title and description are escaped using double quotes (\"), and any internal double quote is escaped by 2 double quotes (\"\"). New lines are escaped by a backslash followed with an \"n\" character, that is \"\\n\".\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "moNmVfuvnImW" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Sebastian Raschka \n", + "\n", + "CPython 3.7.3\n", + "IPython 7.9.0\n", + "\n", + "torch 1.3.0\n" + ] + } + ], + "source": [ + "%load_ext watermark\n", + "%watermark -a 'Sebastian Raschka' -v -p torch\n", + "\n", + "\n", + "import torch\n", + "import torch.nn.functional as F\n", + "from torchtext import data\n", + "from torchtext import datasets\n", + "import time\n", + "import random\n", + "import pandas as pd\n", + "import numpy as np\n", + "\n", + "torch.backends.cudnn.deterministic = True" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "GSRL42Qgy8I8" + }, + "source": [ + "## General Settings" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "OvW1RgfepCBq" + }, + "outputs": [], + "source": [ + "RANDOM_SEED = 123\n", + "torch.manual_seed(RANDOM_SEED)\n", + "\n", + "VOCABULARY_SIZE = 5000\n", + "LEARNING_RATE = 1e-3\n", + "BATCH_SIZE = 128\n", + "NUM_EPOCHS = 50\n", + "DROPOUT = 0.5\n", + "DEVICE = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu')\n", + "\n", + "EMBEDDING_DIM = 128\n", + "BIDIRECTIONAL = True\n", + "HIDDEN_DIM = 256\n", + "NUM_LAYERS = 2\n", + "OUTPUT_DIM = 4" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "mQMmKUEisW4W" + }, + "source": [ + "## Dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The AG News dataset is available from Xiang Zhang's Google Drive folder at\n", + "\n", + "https://drive.google.com/drive/u/0/folders/0Bz8a_Dbh9Qhbfll6bVpmNUtUcFdjYmF2SEpmZUZUcVNiMUw1TWN6RDV3a0JHT3kxLVhVR2M\n", + "\n", + "From the Google Drive folder, download the file \n", + "\n", + "- `ag_news_csv.tar.gz`" + ] + }, + { + "cell_type": 
"code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "ag_news_csv/\n", + "ag_news_csv/train.csv\n", + "ag_news_csv/test.csv\n", + "ag_news_csv/classes.txt\n", + "ag_news_csv/readme.txt\n" + ] + } + ], + "source": [ + "!tar xvzf ag_news_csv.tar.gz" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "World\n", + "Sports\n", + "Business\n", + "Sci/Tech\n" + ] + } + ], + "source": [ + "!cat ag_news_csv/classes.txt" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Check that the dataset looks okay:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
classlabeltitlecontent
02Wall St. Bears Claw Back Into the Black (Reuters)Reuters - Short-sellers, Wall Street's dwindli...
12Carlyle Looks Toward Commercial Aerospace (Reu...Reuters - Private investment firm Carlyle Grou...
22Oil and Economy Cloud Stocks' Outlook (Reuters)Reuters - Soaring crude prices plus worries\\ab...
32Iraq Halts Oil Exports from Main Southern Pipe...Reuters - Authorities have halted oil export\\f...
42Oil prices soar to all-time record, posing new...AFP - Tearaway world oil prices, toppling reco...
\n", + "
" + ], + "text/plain": [ + " classlabel title \\\n", + "0 2 Wall St. Bears Claw Back Into the Black (Reuters) \n", + "1 2 Carlyle Looks Toward Commercial Aerospace (Reu... \n", + "2 2 Oil and Economy Cloud Stocks' Outlook (Reuters) \n", + "3 2 Iraq Halts Oil Exports from Main Southern Pipe... \n", + "4 2 Oil prices soar to all-time record, posing new... \n", + "\n", + " content \n", + "0 Reuters - Short-sellers, Wall Street's dwindli... \n", + "1 Reuters - Private investment firm Carlyle Grou... \n", + "2 Reuters - Soaring crude prices plus worries\\ab... \n", + "3 Reuters - Authorities have halted oil export\\f... \n", + "4 AFP - Tearaway world oil prices, toppling reco... " + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df = pd.read_csv('ag_news_csv/train.csv', header=None, index_col=None)\n", + "df.columns = ['classlabel', 'title', 'content']\n", + "df['classlabel'] = df['classlabel']-1\n", + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([0, 1, 2, 3])" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.unique(df['classlabel'].values)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([30000, 30000, 30000, 30000])" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.bincount(df['classlabel'])" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "df[['classlabel', 'content']].to_csv('ag_news_csv/train_prepocessed.csv', index=None)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
classlabeltitlecontent
02Fears for T N pension after talksUnions representing workers at Turner Newall...
13The Race is On: Second Private Team Sets Launc...SPACE.com - TORONTO, Canada -- A second\\team o...
23Ky. Company Wins Grant to Study Peptides (AP)AP - A company founded by a chemistry research...
33Prediction Unit Helps Forecast Wildfires (AP)AP - It's barely dawn when Mike Fitzpatrick st...
43Calif. Aims to Limit Farm-Related Smog (AP)AP - Southern California's smog-fighting agenc...
\n", + "
" + ], + "text/plain": [ + " classlabel title \\\n", + "0 2 Fears for T N pension after talks \n", + "1 3 The Race is On: Second Private Team Sets Launc... \n", + "2 3 Ky. Company Wins Grant to Study Peptides (AP) \n", + "3 3 Prediction Unit Helps Forecast Wildfires (AP) \n", + "4 3 Calif. Aims to Limit Farm-Related Smog (AP) \n", + "\n", + " content \n", + "0 Unions representing workers at Turner Newall... \n", + "1 SPACE.com - TORONTO, Canada -- A second\\team o... \n", + "2 AP - A company founded by a chemistry research... \n", + "3 AP - It's barely dawn when Mike Fitzpatrick st... \n", + "4 AP - Southern California's smog-fighting agenc... " + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df = pd.read_csv('ag_news_csv/test.csv', header=None, index_col=None)\n", + "df.columns = ['classlabel', 'title', 'content']\n", + "df['classlabel'] = df['classlabel']-1\n", + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([0, 1, 2, 3])" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.unique(df['classlabel'].values)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([1900, 1900, 1900, 1900])" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.bincount(df['classlabel'])" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "df[['classlabel', 'content']].to_csv('ag_news_csv/test_prepocessed.csv', index=None)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "del df" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "4GnH64XvsV8n" + }, + "source": [ + "Define the Label and Text field formatters:" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "TEXT = data.Field(sequential=True,\n", + " tokenize='spacy',\n", + " include_lengths=True) # necessary for packed_padded_sequence\n", + "\n", + "LABEL = data.LabelField(dtype=torch.float)\n", + "\n", + "\n", + "# If you get an error [E050] Can't find model 'en'\n", + "# you need to run the following on your command line:\n", + "# python -m spacy download en" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Process the dataset:" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "fields = [('classlabel', LABEL), ('content', TEXT)]\n", + "\n", + "train_dataset = data.TabularDataset(\n", + " path=\"ag_news_csv/train_prepocessed.csv\", format='csv',\n", + " skip_header=True, fields=fields)\n", + "\n", + "test_dataset = data.TabularDataset(\n", + " path=\"ag_news_csv/test_prepocessed.csv\", format='csv',\n", + " skip_header=True, fields=fields)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Split the training dataset into training and validation:" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 68 + }, + "colab_type": "code", + "id": "WZ_4jiHVnMxN", + "outputId": "dfa51c04-4845-44c3-f50b-d36d41f132b8" + }, + "outputs": [ + { + "name": "stdout", + "output_type": 
"stream", + "text": [ + "Num Train: 114000\n", + "Num Valid: 6000\n" + ] + } + ], + "source": [ + "train_data, valid_data = train_dataset.split(\n", + " split_ratio=[0.95, 0.05],\n", + " random_state=random.seed(RANDOM_SEED))\n", + "\n", + "print(f'Num Train: {len(train_data)}')\n", + "print(f'Num Valid: {len(valid_data)}')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "L-TBwKWPslPa" + }, + "source": [ + "Build the vocabulary based on the top \"VOCABULARY_SIZE\" words:" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 51 + }, + "colab_type": "code", + "id": "e8uNrjdtn4A8", + "outputId": "6cf499d7-7722-4da0-8576-ee0f218cc6e3" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Vocabulary size: 5002\n", + "Number of classes: 4\n" + ] + } + ], + "source": [ + "TEXT.build_vocab(train_data,\n", + " max_size=VOCABULARY_SIZE,\n", + " vectors='glove.6B.100d',\n", + " unk_init=torch.Tensor.normal_)\n", + "\n", + "LABEL.build_vocab(train_data)\n", + "\n", + "print(f'Vocabulary size: {len(TEXT.vocab)}')\n", + "print(f'Number of classes: {len(LABEL.vocab)}')" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['1', '3', '0', '2']" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "list(LABEL.vocab.freqs)[-10:]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "JpEMNInXtZsb" + }, + "source": [ + "The TEXT.vocab dictionary will contain the word counts and indices. The reason why the number of words is VOCABULARY_SIZE + 2 is that it contains to special tokens for padding and unknown words: `` and ``." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "eIQ_zfKLwjKm" + }, + "source": [ + "Make dataset iterators:" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "i7JiHR1stHNF" + }, + "outputs": [], + "source": [ + "train_loader, valid_loader, test_loader = data.BucketIterator.splits(\n", + " (train_data, valid_data, test_dataset), \n", + " batch_size=BATCH_SIZE,\n", + " sort_within_batch=True, # necessary for packed_padded_sequence\n", + " sort_key=lambda x: len(x.content),\n", + " device=DEVICE)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "R0pT_dMRvicQ" + }, + "source": [ + "Testing the iterators (note that the number of rows depends on the longest document in the respective batch):" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 204 + }, + "colab_type": "code", + "id": "y8SP_FccutT0", + "outputId": "fe33763a-4560-4dee-adee-31cc6c48b0b2" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train\n", + "Text matrix size: torch.Size([35, 128])\n", + "Target vector size: torch.Size([128])\n", + "\n", + "Valid:\n", + "Text matrix size: torch.Size([17, 128])\n", + "Target vector size: torch.Size([128])\n", + "\n", + "Test:\n", + "Text matrix size: torch.Size([16, 128])\n", + "Target vector size: torch.Size([128])\n" + ] + } + ], + "source": [ + "print('Train')\n", + "for batch in train_loader:\n", + " print(f'Text matrix size: {batch.content[0].size()}')\n", + " print(f'Target vector size: {batch.classlabel.size()}')\n", + " break\n", + " \n", + "print('\\nValid:')\n", + "for batch in valid_loader:\n", + " print(f'Text matrix size: {batch.content[0].size()}')\n", + " print(f'Target vector size: {batch.classlabel.size()}')\n", + " break\n", + " \n", + "print('\\nTest:')\n", + "for batch in test_loader:\n", + " print(f'Text matrix size: {batch.content[0].size()}')\n", + " print(f'Target vector size: {batch.classlabel.size()}')\n", + " break" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "G_grdW3pxCzz" + }, + "source": [ + "## Model" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "nQIUm5EjxFNa" + }, + "outputs": [], + "source": [ + "import torch.nn as nn\n", + "\n", + "\n", + "class RNN(nn.Module):\n", + " def __init__(self, input_dim, embedding_dim, bidirectional, hidden_dim, num_layers, output_dim, dropout, pad_idx):\n", + " \n", + " super().__init__()\n", + " \n", + " self.embedding = nn.Embedding(input_dim, embedding_dim, padding_idx=pad_idx)\n", + " self.rnn = nn.LSTM(embedding_dim, \n", + " hidden_dim,\n", + " num_layers=num_layers,\n", + " bidirectional=bidirectional, \n", + " dropout=dropout)\n", + " # *2 because the final forward and backward hidden states are concatenated\n", + " self.fc1 = nn.Linear(hidden_dim * 2, 64)\n", + " self.fc2 = nn.Linear(64, output_dim)\n", + " self.dropout = nn.Dropout(dropout)\n", + " \n", + " def forward(self, text, text_length):\n", + "\n", + " embedded = self.dropout(self.embedding(text))\n", + " packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, text_length)\n", + " packed_output, (hidden, cell) = self.rnn(packed_embedded)\n", + " output, output_lengths = nn.utils.rnn.pad_packed_sequence(packed_output)\n", + " hidden = self.dropout(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1))\n", + " hidden = self.fc1(hidden)\n", + " logits = self.fc2(hidden) # project down to OUTPUT_DIM classes\n", + " return logits" + ] + },
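+ { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "For a bidirectional LSTM, `hidden` has shape `[num_layers * num_directions, batch_size, hidden_dim]`, so `hidden[-2,:,:]` and `hidden[-1,:,:]` are the last layer's final forward and backward states; concatenating them gives the `hidden_dim * 2` features that `fc1` expects. A minimal standalone shape check (the sizes below are made-up placeholders, not values from this notebook):\n", + "\n", + "```python\n", + "import torch\n", + "import torch.nn as nn\n", + "\n", + "rnn = nn.LSTM(128, 256, num_layers=2, bidirectional=True)\n", + "x = torch.randn(35, 8, 128) # [seq_len, batch_size, embedding_dim]\n", + "out, (hidden, cell) = rnn(x)\n", + "print(hidden.size()) # torch.Size([4, 8, 256])\n", + "cat = torch.cat((hidden[-2], hidden[-1]), dim=1)\n", + "print(cat.size()) # torch.Size([8, 512]), i.e., hidden_dim * 2\n", + "```" + ] + },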
self.fc1(hidden)\n", + " return hidden" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "Ik3NF3faxFmZ" + }, + "outputs": [], + "source": [ + "INPUT_DIM = len(TEXT.vocab)\n", + "\n", + "PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token]\n", + "\n", + "torch.manual_seed(RANDOM_SEED)\n", + "model = RNN(INPUT_DIM, EMBEDDING_DIM, BIDIRECTIONAL, HIDDEN_DIM, NUM_LAYERS, OUTPUT_DIM, DROPOUT, PAD_IDX)\n", + "model = model.to(DEVICE)\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Lv9Ny9di6VcI" + }, + "source": [ + "## Training" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "T5t1Afn4xO11" + }, + "outputs": [], + "source": [ + "def compute_accuracy(model, data_loader, device):\n", + " model.eval()\n", + " correct_pred, num_examples = 0, 0\n", + " with torch.no_grad():\n", + " for batch_idx, batch_data in enumerate(data_loader):\n", + " text, text_lengths = batch_data.content\n", + " logits = model(text, text_lengths).squeeze(1)\n", + " _, predicted_labels = torch.max(logits, 1)\n", + " num_examples += batch_data.classlabel.size(0)\n", + " correct_pred += (predicted_labels.long() == batch_data.classlabel.long()).sum()\n", + " return correct_pred.float()/num_examples * 100" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1836 + }, + "colab_type": "code", + "id": "EABZM8Vo0ilB", + "outputId": "5d45e293-9909-4588-e793-8dfaf72e5c67" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch: 001/050 | Batch 000/891 | Cost: 4.1667\n", + "Epoch: 001/050 | Batch 050/891 | Cost: 1.4755\n", + "Epoch: 001/050 | Batch 100/891 | Cost: 1.3285\n", + "Epoch: 001/050 | Batch 150/891 | Cost: 1.2829\n", + "Epoch: 001/050 | Batch 200/891 | Cost: 1.1988\n", + "Epoch: 001/050 | Batch 250/891 | Cost: 0.7590\n", + "Epoch: 001/050 | Batch 300/891 | Cost: 0.9044\n", + "Epoch: 001/050 | Batch 350/891 | Cost: 0.8458\n", + "Epoch: 001/050 | Batch 400/891 | Cost: 0.6982\n", + "Epoch: 001/050 | Batch 450/891 | Cost: 0.5888\n", + "Epoch: 001/050 | Batch 500/891 | Cost: 0.5934\n", + "Epoch: 001/050 | Batch 550/891 | Cost: 0.6286\n", + "Epoch: 001/050 | Batch 600/891 | Cost: 0.5483\n", + "Epoch: 001/050 | Batch 650/891 | Cost: 0.6919\n", + "Epoch: 001/050 | Batch 700/891 | Cost: 0.3992\n", + "Epoch: 001/050 | Batch 750/891 | Cost: 0.8836\n", + "Epoch: 001/050 | Batch 800/891 | Cost: 0.5266\n", + "Epoch: 001/050 | Batch 850/891 | Cost: 0.5314\n", + "training accuracy: 84.63%\n", + "valid accuracy: 83.77%\n", + "Time elapsed: 0.46 min\n", + "Epoch: 002/050 | Batch 000/891 | Cost: 0.3907\n", + "Epoch: 002/050 | Batch 050/891 | Cost: 0.3792\n", + "Epoch: 002/050 | Batch 100/891 | Cost: 0.4874\n", + "Epoch: 002/050 | Batch 150/891 | Cost: 0.5823\n", + "Epoch: 002/050 | Batch 200/891 | Cost: 0.4622\n", + "Epoch: 002/050 | Batch 250/891 | Cost: 0.3818\n", + "Epoch: 002/050 | Batch 300/891 | Cost: 0.4743\n", + "Epoch: 002/050 | Batch 350/891 | Cost: 0.5085\n", + "Epoch: 002/050 | Batch 400/891 | Cost: 0.4229\n", + "Epoch: 002/050 | Batch 450/891 | Cost: 0.3666\n", + "Epoch: 002/050 | Batch 500/891 | Cost: 0.3102\n", + "Epoch: 002/050 | Batch 550/891 | Cost: 0.4300\n", + "Epoch: 002/050 | Batch 600/891 | Cost: 0.6906\n", + "Epoch: 002/050 | Batch 650/891 | 
Cost: 0.3315\n", + "Epoch: 002/050 | Batch 700/891 | Cost: 0.4410\n", + "Epoch: 002/050 | Batch 750/891 | Cost: 0.3719\n", + "Epoch: 002/050 | Batch 800/891 | Cost: 0.4229\n", + "Epoch: 002/050 | Batch 850/891 | Cost: 0.5765\n", + "training accuracy: 89.51%\n", + "valid accuracy: 88.93%\n", + "Time elapsed: 0.93 min\n", + "Epoch: 003/050 | Batch 000/891 | Cost: 0.4050\n", + "Epoch: 003/050 | Batch 050/891 | Cost: 0.3719\n", + "Epoch: 003/050 | Batch 100/891 | Cost: 0.3914\n", + "Epoch: 003/050 | Batch 150/891 | Cost: 0.2547\n", + "Epoch: 003/050 | Batch 200/891 | Cost: 0.2478\n", + "Epoch: 003/050 | Batch 250/891 | Cost: 0.6579\n", + "Epoch: 003/050 | Batch 300/891 | Cost: 0.3390\n", + "Epoch: 003/050 | Batch 350/891 | Cost: 0.4368\n", + "Epoch: 003/050 | Batch 400/891 | Cost: 0.3960\n", + "Epoch: 003/050 | Batch 450/891 | Cost: 0.2799\n", + "Epoch: 003/050 | Batch 500/891 | Cost: 0.2862\n", + "Epoch: 003/050 | Batch 550/891 | Cost: 0.3342\n", + "Epoch: 003/050 | Batch 600/891 | Cost: 0.2348\n", + "Epoch: 003/050 | Batch 650/891 | Cost: 0.3088\n", + "Epoch: 003/050 | Batch 700/891 | Cost: 0.7425\n", + "Epoch: 003/050 | Batch 750/891 | Cost: 0.2534\n", + "Epoch: 003/050 | Batch 800/891 | Cost: 0.3224\n", + "Epoch: 003/050 | Batch 850/891 | Cost: 0.2275\n", + "training accuracy: 90.49%\n", + "valid accuracy: 89.32%\n", + "Time elapsed: 1.40 min\n", + "Epoch: 004/050 | Batch 000/891 | Cost: 0.2450\n", + "Epoch: 004/050 | Batch 050/891 | Cost: 0.2518\n", + "Epoch: 004/050 | Batch 100/891 | Cost: 0.6905\n", + "Epoch: 004/050 | Batch 150/891 | Cost: 0.3877\n", + "Epoch: 004/050 | Batch 200/891 | Cost: 0.2438\n", + "Epoch: 004/050 | Batch 250/891 | Cost: 0.2047\n", + "Epoch: 004/050 | Batch 300/891 | Cost: 0.2984\n", + "Epoch: 004/050 | Batch 350/891 | Cost: 0.4487\n", + "Epoch: 004/050 | Batch 400/891 | Cost: 0.2900\n", + "Epoch: 004/050 | Batch 450/891 | Cost: 0.2992\n", + "Epoch: 004/050 | Batch 500/891 | Cost: 0.2952\n", + "Epoch: 004/050 | Batch 550/891 | Cost: 0.2289\n", + "Epoch: 004/050 | Batch 600/891 | Cost: 0.2467\n", + "Epoch: 004/050 | Batch 650/891 | Cost: 0.1343\n", + "Epoch: 004/050 | Batch 700/891 | Cost: 0.2538\n", + "Epoch: 004/050 | Batch 750/891 | Cost: 0.3580\n", + "Epoch: 004/050 | Batch 800/891 | Cost: 0.3781\n", + "Epoch: 004/050 | Batch 850/891 | Cost: 0.2254\n", + "training accuracy: 91.43%\n", + "valid accuracy: 90.47%\n", + "Time elapsed: 1.86 min\n", + "Epoch: 005/050 | Batch 000/891 | Cost: 0.4273\n", + "Epoch: 005/050 | Batch 050/891 | Cost: 0.3250\n", + "Epoch: 005/050 | Batch 100/891 | Cost: 0.4769\n", + "Epoch: 005/050 | Batch 150/891 | Cost: 0.3298\n", + "Epoch: 005/050 | Batch 200/891 | Cost: 0.3183\n", + "Epoch: 005/050 | Batch 250/891 | Cost: 0.2533\n", + "Epoch: 005/050 | Batch 300/891 | Cost: 0.2897\n", + "Epoch: 005/050 | Batch 350/891 | Cost: 0.2772\n", + "Epoch: 005/050 | Batch 400/891 | Cost: 0.3040\n", + "Epoch: 005/050 | Batch 450/891 | Cost: 0.2332\n", + "Epoch: 005/050 | Batch 500/891 | Cost: 0.2608\n", + "Epoch: 005/050 | Batch 550/891 | Cost: 0.2563\n", + "Epoch: 005/050 | Batch 600/891 | Cost: 0.3264\n", + "Epoch: 005/050 | Batch 650/891 | Cost: 0.2695\n", + "Epoch: 005/050 | Batch 700/891 | Cost: 0.4137\n", + "Epoch: 005/050 | Batch 750/891 | Cost: 0.2787\n", + "Epoch: 005/050 | Batch 800/891 | Cost: 0.3102\n", + "Epoch: 005/050 | Batch 850/891 | Cost: 0.2707\n", + "training accuracy: 92.33%\n", + "valid accuracy: 90.92%\n", + "Time elapsed: 2.34 min\n", + "Epoch: 006/050 | Batch 000/891 | Cost: 0.5286\n", + "Epoch: 006/050 | Batch 050/891 | 
Cost: 0.1996\n", + "Epoch: 006/050 | Batch 100/891 | Cost: 0.3859\n", + "Epoch: 006/050 | Batch 150/891 | Cost: 0.2322\n", + "Epoch: 006/050 | Batch 200/891 | Cost: 0.2821\n", + "Epoch: 006/050 | Batch 250/891 | Cost: 0.3530\n", + "Epoch: 006/050 | Batch 300/891 | Cost: 0.3880\n", + "Epoch: 006/050 | Batch 350/891 | Cost: 0.4259\n", + "Epoch: 006/050 | Batch 400/891 | Cost: 0.3522\n", + "Epoch: 006/050 | Batch 450/891 | Cost: 0.3299\n", + "Epoch: 006/050 | Batch 500/891 | Cost: 0.3318\n", + "Epoch: 006/050 | Batch 550/891 | Cost: 0.3139\n", + "Epoch: 006/050 | Batch 600/891 | Cost: 0.2604\n", + "Epoch: 006/050 | Batch 650/891 | Cost: 0.2049\n", + "Epoch: 006/050 | Batch 700/891 | Cost: 0.2948\n", + "Epoch: 006/050 | Batch 750/891 | Cost: 0.2000\n", + "Epoch: 006/050 | Batch 800/891 | Cost: 0.1694\n", + "Epoch: 006/050 | Batch 850/891 | Cost: 0.3553\n", + "training accuracy: 92.73%\n", + "valid accuracy: 91.20%\n", + "Time elapsed: 2.81 min\n", + "Epoch: 007/050 | Batch 000/891 | Cost: 0.2787\n", + "Epoch: 007/050 | Batch 050/891 | Cost: 0.2766\n", + "Epoch: 007/050 | Batch 100/891 | Cost: 0.3586\n", + "Epoch: 007/050 | Batch 150/891 | Cost: 0.2167\n", + "Epoch: 007/050 | Batch 200/891 | Cost: 0.2809\n", + "Epoch: 007/050 | Batch 250/891 | Cost: 0.1589\n", + "Epoch: 007/050 | Batch 300/891 | Cost: 0.2980\n", + "Epoch: 007/050 | Batch 350/891 | Cost: 0.2061\n", + "Epoch: 007/050 | Batch 400/891 | Cost: 0.2757\n", + "Epoch: 007/050 | Batch 450/891 | Cost: 0.2706\n", + "Epoch: 007/050 | Batch 500/891 | Cost: 0.1621\n", + "Epoch: 007/050 | Batch 550/891 | Cost: 0.2763\n", + "Epoch: 007/050 | Batch 600/891 | Cost: 0.2122\n", + "Epoch: 007/050 | Batch 650/891 | Cost: 0.3193\n", + "Epoch: 007/050 | Batch 700/891 | Cost: 0.3161\n", + "Epoch: 007/050 | Batch 750/891 | Cost: 0.5697\n", + "Epoch: 007/050 | Batch 800/891 | Cost: 0.2462\n", + "Epoch: 007/050 | Batch 850/891 | Cost: 0.4072\n", + "training accuracy: 93.18%\n", + "valid accuracy: 91.13%\n", + "Time elapsed: 3.29 min\n", + "Epoch: 008/050 | Batch 000/891 | Cost: 0.1531\n", + "Epoch: 008/050 | Batch 050/891 | Cost: 0.2815\n", + "Epoch: 008/050 | Batch 100/891 | Cost: 0.1890\n", + "Epoch: 008/050 | Batch 150/891 | Cost: 0.3430\n", + "Epoch: 008/050 | Batch 200/891 | Cost: 0.3179\n", + "Epoch: 008/050 | Batch 250/891 | Cost: 0.1990\n", + "Epoch: 008/050 | Batch 300/891 | Cost: 0.2313\n", + "Epoch: 008/050 | Batch 350/891 | Cost: 0.1431\n", + "Epoch: 008/050 | Batch 400/891 | Cost: 0.1857\n", + "Epoch: 008/050 | Batch 450/891 | Cost: 0.3604\n", + "Epoch: 008/050 | Batch 500/891 | Cost: 0.3531\n", + "Epoch: 008/050 | Batch 550/891 | Cost: 0.2136\n", + "Epoch: 008/050 | Batch 600/891 | Cost: 0.3887\n", + "Epoch: 008/050 | Batch 650/891 | Cost: 0.2011\n", + "Epoch: 008/050 | Batch 700/891 | Cost: 0.1803\n", + "Epoch: 008/050 | Batch 750/891 | Cost: 0.3328\n", + "Epoch: 008/050 | Batch 800/891 | Cost: 0.2284\n", + "Epoch: 008/050 | Batch 850/891 | Cost: 0.1928\n", + "training accuracy: 93.41%\n", + "valid accuracy: 91.22%\n", + "Time elapsed: 3.77 min\n", + "Epoch: 009/050 | Batch 000/891 | Cost: 0.2629\n", + "Epoch: 009/050 | Batch 050/891 | Cost: 0.2781\n", + "Epoch: 009/050 | Batch 100/891 | Cost: 0.2318\n", + "Epoch: 009/050 | Batch 150/891 | Cost: 0.2701\n", + "Epoch: 009/050 | Batch 200/891 | Cost: 0.1944\n", + "Epoch: 009/050 | Batch 250/891 | Cost: 0.3229\n", + "Epoch: 009/050 | Batch 300/891 | Cost: 0.2979\n", + "Epoch: 009/050 | Batch 350/891 | Cost: 0.2095\n", + "Epoch: 009/050 | Batch 400/891 | Cost: 0.1358\n", + "Epoch: 009/050 | Batch 
450/891 | Cost: 0.2221\n", + "Epoch: 009/050 | Batch 500/891 | Cost: 0.1896\n", + "Epoch: 009/050 | Batch 550/891 | Cost: 0.2059\n", + "Epoch: 009/050 | Batch 600/891 | Cost: 0.2914\n", + "Epoch: 009/050 | Batch 650/891 | Cost: 0.4117\n", + "Epoch: 009/050 | Batch 700/891 | Cost: 0.2545\n", + "Epoch: 009/050 | Batch 750/891 | Cost: 0.3608\n", + "Epoch: 009/050 | Batch 800/891 | Cost: 0.2593\n", + "Epoch: 009/050 | Batch 850/891 | Cost: 0.1308\n", + "training accuracy: 93.86%\n", + "valid accuracy: 91.33%\n", + "Time elapsed: 4.26 min\n", + "Epoch: 010/050 | Batch 000/891 | Cost: 0.1087\n", + "Epoch: 010/050 | Batch 050/891 | Cost: 0.1921\n", + "Epoch: 010/050 | Batch 100/891 | Cost: 0.1257\n", + "Epoch: 010/050 | Batch 150/891 | Cost: 0.4087\n", + "Epoch: 010/050 | Batch 200/891 | Cost: 0.2603\n", + "Epoch: 010/050 | Batch 250/891 | Cost: 0.1607\n", + "Epoch: 010/050 | Batch 300/891 | Cost: 0.2791\n", + "Epoch: 010/050 | Batch 350/891 | Cost: 0.1774\n", + "Epoch: 010/050 | Batch 400/891 | Cost: 0.5015\n", + "Epoch: 010/050 | Batch 450/891 | Cost: 0.2276\n", + "Epoch: 010/050 | Batch 500/891 | Cost: 0.2954\n", + "Epoch: 010/050 | Batch 550/891 | Cost: 0.1906\n", + "Epoch: 010/050 | Batch 600/891 | Cost: 0.2464\n", + "Epoch: 010/050 | Batch 650/891 | Cost: 0.2425\n", + "Epoch: 010/050 | Batch 700/891 | Cost: 0.2000\n", + "Epoch: 010/050 | Batch 750/891 | Cost: 0.2981\n", + "Epoch: 010/050 | Batch 800/891 | Cost: 0.2060\n", + "Epoch: 010/050 | Batch 850/891 | Cost: 0.2032\n", + "training accuracy: 94.30%\n", + "valid accuracy: 91.70%\n", + "Time elapsed: 4.74 min\n", + "Epoch: 011/050 | Batch 000/891 | Cost: 0.2229\n", + "Epoch: 011/050 | Batch 050/891 | Cost: 0.2725\n", + "Epoch: 011/050 | Batch 100/891 | Cost: 0.1801\n", + "Epoch: 011/050 | Batch 150/891 | Cost: 0.2125\n", + "Epoch: 011/050 | Batch 200/891 | Cost: 0.1482\n", + "Epoch: 011/050 | Batch 250/891 | Cost: 0.2237\n", + "Epoch: 011/050 | Batch 300/891 | Cost: 0.1581\n", + "Epoch: 011/050 | Batch 350/891 | Cost: 0.3981\n", + "Epoch: 011/050 | Batch 400/891 | Cost: 0.2683\n", + "Epoch: 011/050 | Batch 450/891 | Cost: 0.2471\n", + "Epoch: 011/050 | Batch 500/891 | Cost: 0.1495\n", + "Epoch: 011/050 | Batch 550/891 | Cost: 0.2281\n", + "Epoch: 011/050 | Batch 600/891 | Cost: 0.2023\n", + "Epoch: 011/050 | Batch 650/891 | Cost: 0.1069\n", + "Epoch: 011/050 | Batch 700/891 | Cost: 0.1906\n", + "Epoch: 011/050 | Batch 750/891 | Cost: 0.2770\n", + "Epoch: 011/050 | Batch 800/891 | Cost: 0.1736\n", + "Epoch: 011/050 | Batch 850/891 | Cost: 0.1480\n", + "training accuracy: 94.57%\n", + "valid accuracy: 91.77%\n", + "Time elapsed: 5.23 min\n", + "Epoch: 012/050 | Batch 000/891 | Cost: 0.1419\n", + "Epoch: 012/050 | Batch 050/891 | Cost: 0.2082\n", + "Epoch: 012/050 | Batch 100/891 | Cost: 0.1527\n", + "Epoch: 012/050 | Batch 150/891 | Cost: 0.1564\n", + "Epoch: 012/050 | Batch 200/891 | Cost: 0.2391\n", + "Epoch: 012/050 | Batch 250/891 | Cost: 0.3568\n", + "Epoch: 012/050 | Batch 300/891 | Cost: 0.0926\n", + "Epoch: 012/050 | Batch 350/891 | Cost: 0.1798\n", + "Epoch: 012/050 | Batch 400/891 | Cost: 0.2591\n", + "Epoch: 012/050 | Batch 450/891 | Cost: 0.2005\n", + "Epoch: 012/050 | Batch 500/891 | Cost: 0.1461\n", + "Epoch: 012/050 | Batch 550/891 | Cost: 0.2099\n", + "Epoch: 012/050 | Batch 600/891 | Cost: 0.1473\n", + "Epoch: 012/050 | Batch 650/891 | Cost: 0.2052\n", + "Epoch: 012/050 | Batch 700/891 | Cost: 0.2090\n", + "Epoch: 012/050 | Batch 750/891 | Cost: 0.3133\n", + "Epoch: 012/050 | Batch 800/891 | Cost: 0.0936\n", + "Epoch: 
012/050 | Batch 850/891 | Cost: 0.1964\n", + "training accuracy: 94.91%\n", + "valid accuracy: 91.92%\n", + "Time elapsed: 5.71 min\n", + "Epoch: 013/050 | Batch 000/891 | Cost: 0.1882\n", + "Epoch: 013/050 | Batch 050/891 | Cost: 0.1726\n", + "Epoch: 013/050 | Batch 100/891 | Cost: 0.2273\n", + "Epoch: 013/050 | Batch 150/891 | Cost: 0.4143\n", + "Epoch: 013/050 | Batch 200/891 | Cost: 0.1912\n", + "Epoch: 013/050 | Batch 250/891 | Cost: 0.1610\n", + "Epoch: 013/050 | Batch 300/891 | Cost: 0.2238\n", + "Epoch: 013/050 | Batch 350/891 | Cost: 0.3671\n", + "Epoch: 013/050 | Batch 400/891 | Cost: 0.1471\n", + "Epoch: 013/050 | Batch 450/891 | Cost: 0.2440\n", + "Epoch: 013/050 | Batch 500/891 | Cost: 0.2701\n", + "Epoch: 013/050 | Batch 550/891 | Cost: 0.2684\n", + "Epoch: 013/050 | Batch 600/891 | Cost: 0.1602\n", + "Epoch: 013/050 | Batch 650/891 | Cost: 0.2128\n", + "Epoch: 013/050 | Batch 700/891 | Cost: 0.0978\n", + "Epoch: 013/050 | Batch 750/891 | Cost: 0.2017\n", + "Epoch: 013/050 | Batch 800/891 | Cost: 0.0781\n", + "Epoch: 013/050 | Batch 850/891 | Cost: 0.2742\n", + "training accuracy: 95.13%\n", + "valid accuracy: 91.83%\n", + "Time elapsed: 6.19 min\n", + "Epoch: 014/050 | Batch 000/891 | Cost: 0.1590\n", + "Epoch: 014/050 | Batch 050/891 | Cost: 0.1685\n", + "Epoch: 014/050 | Batch 100/891 | Cost: 0.2997\n", + "Epoch: 014/050 | Batch 150/891 | Cost: 0.0779\n", + "Epoch: 014/050 | Batch 200/891 | Cost: 0.1422\n", + "Epoch: 014/050 | Batch 250/891 | Cost: 0.2610\n", + "Epoch: 014/050 | Batch 300/891 | Cost: 0.2471\n", + "Epoch: 014/050 | Batch 350/891 | Cost: 0.1126\n", + "Epoch: 014/050 | Batch 400/891 | Cost: 0.5214\n", + "Epoch: 014/050 | Batch 450/891 | Cost: 0.1805\n", + "Epoch: 014/050 | Batch 500/891 | Cost: 0.3690\n", + "Epoch: 014/050 | Batch 550/891 | Cost: 0.1889\n", + "Epoch: 014/050 | Batch 600/891 | Cost: 0.2583\n", + "Epoch: 014/050 | Batch 650/891 | Cost: 0.1955\n", + "Epoch: 014/050 | Batch 700/891 | Cost: 0.3968\n", + "Epoch: 014/050 | Batch 750/891 | Cost: 0.4153\n", + "Epoch: 014/050 | Batch 800/891 | Cost: 0.2386\n", + "Epoch: 014/050 | Batch 850/891 | Cost: 0.2618\n", + "training accuracy: 95.26%\n", + "valid accuracy: 91.82%\n", + "Time elapsed: 6.68 min\n", + "Epoch: 015/050 | Batch 000/891 | Cost: 0.2095\n", + "Epoch: 015/050 | Batch 050/891 | Cost: 0.1248\n", + "Epoch: 015/050 | Batch 100/891 | Cost: 0.2129\n", + "Epoch: 015/050 | Batch 150/891 | Cost: 0.1529\n", + "Epoch: 015/050 | Batch 200/891 | Cost: 0.1211\n", + "Epoch: 015/050 | Batch 250/891 | Cost: 0.2485\n", + "Epoch: 015/050 | Batch 300/891 | Cost: 0.1596\n", + "Epoch: 015/050 | Batch 350/891 | Cost: 0.2131\n", + "Epoch: 015/050 | Batch 400/891 | Cost: 0.2655\n", + "Epoch: 015/050 | Batch 450/891 | Cost: 0.1532\n", + "Epoch: 015/050 | Batch 500/891 | Cost: 0.1442\n", + "Epoch: 015/050 | Batch 550/891 | Cost: 0.2170\n", + "Epoch: 015/050 | Batch 600/891 | Cost: 0.2097\n", + "Epoch: 015/050 | Batch 650/891 | Cost: 0.1731\n", + "Epoch: 015/050 | Batch 700/891 | Cost: 0.2049\n", + "Epoch: 015/050 | Batch 750/891 | Cost: 0.1335\n", + "Epoch: 015/050 | Batch 800/891 | Cost: 0.1869\n", + "Epoch: 015/050 | Batch 850/891 | Cost: 0.1313\n", + "training accuracy: 95.62%\n", + "valid accuracy: 91.93%\n", + "Time elapsed: 7.17 min\n", + "Epoch: 016/050 | Batch 000/891 | Cost: 0.2243\n", + "Epoch: 016/050 | Batch 050/891 | Cost: 0.1787\n", + "Epoch: 016/050 | Batch 100/891 | Cost: 0.0720\n", + "Epoch: 016/050 | Batch 150/891 | Cost: 0.1693\n", + "Epoch: 016/050 | Batch 200/891 | Cost: 0.0990\n", + "Epoch: 
016/050 | Batch 250/891 | Cost: 0.2836\n", + "Epoch: 016/050 | Batch 300/891 | Cost: 0.1295\n", + "Epoch: 016/050 | Batch 350/891 | Cost: 0.0999\n", + "Epoch: 016/050 | Batch 400/891 | Cost: 0.1612\n", + "Epoch: 016/050 | Batch 450/891 | Cost: 0.2436\n", + "Epoch: 016/050 | Batch 500/891 | Cost: 0.2344\n", + "Epoch: 016/050 | Batch 550/891 | Cost: 0.2931\n", + "Epoch: 016/050 | Batch 600/891 | Cost: 0.0864\n", + "Epoch: 016/050 | Batch 650/891 | Cost: 0.2007\n", + "Epoch: 016/050 | Batch 700/891 | Cost: 0.1101\n", + "Epoch: 016/050 | Batch 750/891 | Cost: 0.2093\n", + "Epoch: 016/050 | Batch 800/891 | Cost: 0.1148\n", + "Epoch: 016/050 | Batch 850/891 | Cost: 0.1621\n", + "training accuracy: 95.66%\n", + "valid accuracy: 91.53%\n", + "Time elapsed: 7.65 min\n", + "Epoch: 017/050 | Batch 000/891 | Cost: 0.3486\n", + "Epoch: 017/050 | Batch 050/891 | Cost: 0.1839\n", + "Epoch: 017/050 | Batch 100/891 | Cost: 0.0831\n", + "Epoch: 017/050 | Batch 150/891 | Cost: 0.1529\n", + "Epoch: 017/050 | Batch 200/891 | Cost: 0.2675\n", + "Epoch: 017/050 | Batch 250/891 | Cost: 0.1468\n", + "Epoch: 017/050 | Batch 300/891 | Cost: 0.1797\n", + "Epoch: 017/050 | Batch 350/891 | Cost: 0.1586\n", + "Epoch: 017/050 | Batch 400/891 | Cost: 0.1128\n", + "Epoch: 017/050 | Batch 450/891 | Cost: 0.1678\n", + "Epoch: 017/050 | Batch 500/891 | Cost: 0.1740\n", + "Epoch: 017/050 | Batch 550/891 | Cost: 0.2684\n", + "Epoch: 017/050 | Batch 600/891 | Cost: 0.1596\n", + "Epoch: 017/050 | Batch 650/891 | Cost: 0.2647\n", + "Epoch: 017/050 | Batch 700/891 | Cost: 0.1738\n", + "Epoch: 017/050 | Batch 750/891 | Cost: 0.2119\n", + "Epoch: 017/050 | Batch 800/891 | Cost: 0.1385\n", + "Epoch: 017/050 | Batch 850/891 | Cost: 0.1648\n", + "training accuracy: 95.90%\n", + "valid accuracy: 91.77%\n", + "Time elapsed: 8.14 min\n", + "Epoch: 018/050 | Batch 000/891 | Cost: 0.1678\n", + "Epoch: 018/050 | Batch 050/891 | Cost: 0.1260\n", + "Epoch: 018/050 | Batch 100/891 | Cost: 0.1912\n", + "Epoch: 018/050 | Batch 150/891 | Cost: 0.1299\n", + "Epoch: 018/050 | Batch 200/891 | Cost: 0.1702\n", + "Epoch: 018/050 | Batch 250/891 | Cost: 0.1456\n", + "Epoch: 018/050 | Batch 300/891 | Cost: 0.1284\n", + "Epoch: 018/050 | Batch 350/891 | Cost: 0.2763\n", + "Epoch: 018/050 | Batch 400/891 | Cost: 0.0950\n", + "Epoch: 018/050 | Batch 450/891 | Cost: 0.1417\n", + "Epoch: 018/050 | Batch 500/891 | Cost: 0.2453\n", + "Epoch: 018/050 | Batch 550/891 | Cost: 0.2603\n", + "Epoch: 018/050 | Batch 600/891 | Cost: 0.2635\n", + "Epoch: 018/050 | Batch 650/891 | Cost: 0.1849\n", + "Epoch: 018/050 | Batch 700/891 | Cost: 0.1742\n", + "Epoch: 018/050 | Batch 750/891 | Cost: 0.1185\n", + "Epoch: 018/050 | Batch 800/891 | Cost: 0.4024\n", + "Epoch: 018/050 | Batch 850/891 | Cost: 0.1221\n", + "training accuracy: 96.03%\n", + "valid accuracy: 91.83%\n", + "Time elapsed: 8.63 min\n", + "Epoch: 019/050 | Batch 000/891 | Cost: 0.1801\n", + "Epoch: 019/050 | Batch 050/891 | Cost: 0.2904\n", + "Epoch: 019/050 | Batch 100/891 | Cost: 0.1423\n", + "Epoch: 019/050 | Batch 150/891 | Cost: 0.2176\n", + "Epoch: 019/050 | Batch 200/891 | Cost: 0.2692\n", + "Epoch: 019/050 | Batch 250/891 | Cost: 0.1769\n", + "Epoch: 019/050 | Batch 300/891 | Cost: 0.1792\n", + "Epoch: 019/050 | Batch 350/891 | Cost: 0.4244\n", + "Epoch: 019/050 | Batch 400/891 | Cost: 0.1208\n", + "Epoch: 019/050 | Batch 450/891 | Cost: 0.3000\n", + "Epoch: 019/050 | Batch 500/891 | Cost: 0.1977\n", + "Epoch: 019/050 | Batch 550/891 | Cost: 0.2125\n", + "Epoch: 019/050 | Batch 600/891 | Cost: 0.1181\n", 
+ "Epoch: 019/050 | Batch 650/891 | Cost: 0.1804\n", + "Epoch: 019/050 | Batch 700/891 | Cost: 0.1098\n", + "Epoch: 019/050 | Batch 750/891 | Cost: 0.2638\n", + "Epoch: 019/050 | Batch 800/891 | Cost: 0.1524\n", + "Epoch: 019/050 | Batch 850/891 | Cost: 0.2061\n", + "training accuracy: 96.25%\n", + "valid accuracy: 91.98%\n", + "Time elapsed: 9.11 min\n", + "Epoch: 020/050 | Batch 000/891 | Cost: 0.1519\n", + "Epoch: 020/050 | Batch 050/891 | Cost: 0.1323\n", + "Epoch: 020/050 | Batch 100/891 | Cost: 0.1637\n", + "Epoch: 020/050 | Batch 150/891 | Cost: 0.2232\n", + "Epoch: 020/050 | Batch 200/891 | Cost: 0.4358\n", + "Epoch: 020/050 | Batch 250/891 | Cost: 0.1855\n", + "Epoch: 020/050 | Batch 300/891 | Cost: 0.2004\n", + "Epoch: 020/050 | Batch 350/891 | Cost: 0.0560\n", + "Epoch: 020/050 | Batch 400/891 | Cost: 0.0841\n", + "Epoch: 020/050 | Batch 450/891 | Cost: 0.0601\n", + "Epoch: 020/050 | Batch 500/891 | Cost: 0.0987\n", + "Epoch: 020/050 | Batch 550/891 | Cost: 0.1021\n", + "Epoch: 020/050 | Batch 600/891 | Cost: 0.4316\n", + "Epoch: 020/050 | Batch 650/891 | Cost: 0.1060\n", + "Epoch: 020/050 | Batch 700/891 | Cost: 0.1655\n", + "Epoch: 020/050 | Batch 750/891 | Cost: 0.1303\n", + "Epoch: 020/050 | Batch 800/891 | Cost: 0.2889\n", + "Epoch: 020/050 | Batch 850/891 | Cost: 0.0948\n", + "training accuracy: 96.46%\n", + "valid accuracy: 91.87%\n", + "Time elapsed: 9.59 min\n", + "Epoch: 021/050 | Batch 000/891 | Cost: 0.1513\n", + "Epoch: 021/050 | Batch 050/891 | Cost: 0.1063\n", + "Epoch: 021/050 | Batch 100/891 | Cost: 0.0808\n", + "Epoch: 021/050 | Batch 150/891 | Cost: 0.1390\n", + "Epoch: 021/050 | Batch 200/891 | Cost: 0.1452\n", + "Epoch: 021/050 | Batch 250/891 | Cost: 0.2100\n", + "Epoch: 021/050 | Batch 300/891 | Cost: 0.1803\n", + "Epoch: 021/050 | Batch 350/891 | Cost: 0.1057\n", + "Epoch: 021/050 | Batch 400/891 | Cost: 0.1293\n", + "Epoch: 021/050 | Batch 450/891 | Cost: 0.1064\n", + "Epoch: 021/050 | Batch 500/891 | Cost: 0.1383\n", + "Epoch: 021/050 | Batch 550/891 | Cost: 0.1331\n", + "Epoch: 021/050 | Batch 600/891 | Cost: 0.2483\n", + "Epoch: 021/050 | Batch 650/891 | Cost: 0.1053\n", + "Epoch: 021/050 | Batch 700/891 | Cost: 0.0852\n", + "Epoch: 021/050 | Batch 750/891 | Cost: 0.0939\n", + "Epoch: 021/050 | Batch 800/891 | Cost: 0.1492\n", + "Epoch: 021/050 | Batch 850/891 | Cost: 0.1075\n", + "training accuracy: 96.56%\n", + "valid accuracy: 92.13%\n", + "Time elapsed: 10.33 min\n", + "Epoch: 022/050 | Batch 000/891 | Cost: 0.1810\n", + "Epoch: 022/050 | Batch 050/891 | Cost: 0.1069\n", + "Epoch: 022/050 | Batch 100/891 | Cost: 0.1601\n", + "Epoch: 022/050 | Batch 150/891 | Cost: 0.1092\n", + "Epoch: 022/050 | Batch 200/891 | Cost: 0.2255\n", + "Epoch: 022/050 | Batch 250/891 | Cost: 0.3778\n", + "Epoch: 022/050 | Batch 300/891 | Cost: 0.1875\n", + "Epoch: 022/050 | Batch 350/891 | Cost: 0.1854\n", + "Epoch: 022/050 | Batch 400/891 | Cost: 0.3620\n", + "Epoch: 022/050 | Batch 450/891 | Cost: 0.1210\n", + "Epoch: 022/050 | Batch 500/891 | Cost: 0.0647\n", + "Epoch: 022/050 | Batch 550/891 | Cost: 0.2215\n", + "Epoch: 022/050 | Batch 600/891 | Cost: 0.1141\n", + "Epoch: 022/050 | Batch 650/891 | Cost: 0.1765\n", + "Epoch: 022/050 | Batch 700/891 | Cost: 0.1067\n", + "Epoch: 022/050 | Batch 750/891 | Cost: 0.1907\n", + "Epoch: 022/050 | Batch 800/891 | Cost: 0.1374\n", + "Epoch: 022/050 | Batch 850/891 | Cost: 0.1366\n", + "training accuracy: 96.75%\n", + "valid accuracy: 91.88%\n", + "Time elapsed: 11.35 min\n", + "Epoch: 023/050 | Batch 000/891 | Cost: 0.0993\n", 
+ "Epoch: 023/050 | Batch 050/891 | Cost: 0.1212\n", + "Epoch: 023/050 | Batch 100/891 | Cost: 0.1991\n", + "Epoch: 023/050 | Batch 150/891 | Cost: 0.2732\n", + "Epoch: 023/050 | Batch 200/891 | Cost: 0.2020\n", + "Epoch: 023/050 | Batch 250/891 | Cost: 0.0996\n", + "Epoch: 023/050 | Batch 300/891 | Cost: 0.2931\n", + "Epoch: 023/050 | Batch 350/891 | Cost: 0.1590\n", + "Epoch: 023/050 | Batch 400/891 | Cost: 0.3799\n", + "Epoch: 023/050 | Batch 450/891 | Cost: 0.2423\n", + "Epoch: 023/050 | Batch 500/891 | Cost: 0.1465\n", + "Epoch: 023/050 | Batch 550/891 | Cost: 0.1157\n", + "Epoch: 023/050 | Batch 600/891 | Cost: 0.2244\n", + "Epoch: 023/050 | Batch 650/891 | Cost: 0.1930\n", + "Epoch: 023/050 | Batch 700/891 | Cost: 0.1244\n", + "Epoch: 023/050 | Batch 750/891 | Cost: 0.1410\n", + "Epoch: 023/050 | Batch 800/891 | Cost: 0.1642\n", + "Epoch: 023/050 | Batch 850/891 | Cost: 0.1734\n", + "training accuracy: 96.90%\n", + "valid accuracy: 91.63%\n", + "Time elapsed: 12.39 min\n", + "Epoch: 024/050 | Batch 000/891 | Cost: 0.0709\n", + "Epoch: 024/050 | Batch 050/891 | Cost: 0.1248\n", + "Epoch: 024/050 | Batch 100/891 | Cost: 0.1629\n", + "Epoch: 024/050 | Batch 150/891 | Cost: 0.1777\n", + "Epoch: 024/050 | Batch 200/891 | Cost: 0.2100\n", + "Epoch: 024/050 | Batch 250/891 | Cost: 0.1991\n", + "Epoch: 024/050 | Batch 300/891 | Cost: 0.4561\n", + "Epoch: 024/050 | Batch 350/891 | Cost: 0.1529\n", + "Epoch: 024/050 | Batch 400/891 | Cost: 0.1097\n", + "Epoch: 024/050 | Batch 450/891 | Cost: 0.1213\n", + "Epoch: 024/050 | Batch 500/891 | Cost: 0.1387\n", + "Epoch: 024/050 | Batch 550/891 | Cost: 0.2177\n", + "Epoch: 024/050 | Batch 600/891 | Cost: 0.1028\n", + "Epoch: 024/050 | Batch 650/891 | Cost: 0.2664\n", + "Epoch: 024/050 | Batch 700/891 | Cost: 0.0694\n", + "Epoch: 024/050 | Batch 750/891 | Cost: 0.0847\n", + "Epoch: 024/050 | Batch 800/891 | Cost: 0.1983\n", + "Epoch: 024/050 | Batch 850/891 | Cost: 0.2498\n", + "training accuracy: 97.16%\n", + "valid accuracy: 91.93%\n", + "Time elapsed: 13.42 min\n", + "Epoch: 025/050 | Batch 000/891 | Cost: 0.1991\n", + "Epoch: 025/050 | Batch 050/891 | Cost: 0.0666\n", + "Epoch: 025/050 | Batch 100/891 | Cost: 0.1780\n", + "Epoch: 025/050 | Batch 150/891 | Cost: 0.1563\n", + "Epoch: 025/050 | Batch 200/891 | Cost: 0.0882\n", + "Epoch: 025/050 | Batch 250/891 | Cost: 0.2989\n", + "Epoch: 025/050 | Batch 300/891 | Cost: 0.1824\n", + "Epoch: 025/050 | Batch 350/891 | Cost: 0.2966\n", + "Epoch: 025/050 | Batch 400/891 | Cost: 0.2031\n", + "Epoch: 025/050 | Batch 450/891 | Cost: 0.1180\n", + "Epoch: 025/050 | Batch 500/891 | Cost: 0.3109\n", + "Epoch: 025/050 | Batch 550/891 | Cost: 0.1684\n", + "Epoch: 025/050 | Batch 600/891 | Cost: 0.0875\n", + "Epoch: 025/050 | Batch 650/891 | Cost: 0.1391\n", + "Epoch: 025/050 | Batch 700/891 | Cost: 0.1274\n", + "Epoch: 025/050 | Batch 750/891 | Cost: 0.2153\n", + "Epoch: 025/050 | Batch 800/891 | Cost: 0.1216\n", + "Epoch: 025/050 | Batch 850/891 | Cost: 0.1828\n", + "training accuracy: 97.05%\n", + "valid accuracy: 91.38%\n", + "Time elapsed: 14.47 min\n", + "Epoch: 026/050 | Batch 000/891 | Cost: 0.1344\n", + "Epoch: 026/050 | Batch 050/891 | Cost: 0.2940\n", + "Epoch: 026/050 | Batch 100/891 | Cost: 0.1692\n", + "Epoch: 026/050 | Batch 150/891 | Cost: 0.1281\n", + "Epoch: 026/050 | Batch 200/891 | Cost: 0.1737\n", + "Epoch: 026/050 | Batch 250/891 | Cost: 0.2194\n", + "Epoch: 026/050 | Batch 300/891 | Cost: 0.3692\n", + "Epoch: 026/050 | Batch 350/891 | Cost: 0.2095\n", + "Epoch: 026/050 | Batch 400/891 | 
Cost: 0.2085\n", + "Epoch: 026/050 | Batch 450/891 | Cost: 0.2011\n", + "Epoch: 026/050 | Batch 500/891 | Cost: 0.2066\n", + "Epoch: 026/050 | Batch 550/891 | Cost: 0.3383\n", + "Epoch: 026/050 | Batch 600/891 | Cost: 0.2015\n", + "Epoch: 026/050 | Batch 650/891 | Cost: 0.1520\n", + "Epoch: 026/050 | Batch 700/891 | Cost: 0.0984\n", + "Epoch: 026/050 | Batch 750/891 | Cost: 0.0933\n", + "Epoch: 026/050 | Batch 800/891 | Cost: 0.2503\n", + "Epoch: 026/050 | Batch 850/891 | Cost: 0.1500\n", + "training accuracy: 97.30%\n", + "valid accuracy: 91.88%\n", + "Time elapsed: 15.54 min\n", + "Epoch: 027/050 | Batch 000/891 | Cost: 0.1133\n", + "Epoch: 027/050 | Batch 050/891 | Cost: 0.0566\n", + "Epoch: 027/050 | Batch 100/891 | Cost: 0.1300\n", + "Epoch: 027/050 | Batch 150/891 | Cost: 0.1017\n", + "Epoch: 027/050 | Batch 200/891 | Cost: 0.1233\n", + "Epoch: 027/050 | Batch 250/891 | Cost: 0.2639\n", + "Epoch: 027/050 | Batch 300/891 | Cost: 0.1417\n", + "Epoch: 027/050 | Batch 350/891 | Cost: 0.1526\n", + "Epoch: 027/050 | Batch 400/891 | Cost: 0.1113\n", + "Epoch: 027/050 | Batch 450/891 | Cost: 0.1807\n", + "Epoch: 027/050 | Batch 500/891 | Cost: 0.2097\n", + "Epoch: 027/050 | Batch 550/891 | Cost: 0.0656\n", + "Epoch: 027/050 | Batch 600/891 | Cost: 0.1461\n", + "Epoch: 027/050 | Batch 650/891 | Cost: 0.0721\n", + "Epoch: 027/050 | Batch 700/891 | Cost: 0.1089\n", + "Epoch: 027/050 | Batch 750/891 | Cost: 0.1491\n", + "Epoch: 027/050 | Batch 800/891 | Cost: 0.2305\n", + "Epoch: 027/050 | Batch 850/891 | Cost: 0.1258\n", + "training accuracy: 97.38%\n", + "valid accuracy: 92.00%\n", + "Time elapsed: 16.61 min\n", + "Epoch: 028/050 | Batch 000/891 | Cost: 0.0894\n", + "Epoch: 028/050 | Batch 050/891 | Cost: 0.1093\n", + "Epoch: 028/050 | Batch 100/891 | Cost: 0.1931\n", + "Epoch: 028/050 | Batch 150/891 | Cost: 0.1843\n", + "Epoch: 028/050 | Batch 200/891 | Cost: 0.1760\n", + "Epoch: 028/050 | Batch 250/891 | Cost: 0.0717\n", + "Epoch: 028/050 | Batch 300/891 | Cost: 0.1854\n", + "Epoch: 028/050 | Batch 350/891 | Cost: 0.1044\n", + "Epoch: 028/050 | Batch 400/891 | Cost: 0.1138\n", + "Epoch: 028/050 | Batch 450/891 | Cost: 0.1639\n", + "Epoch: 028/050 | Batch 500/891 | Cost: 0.1970\n", + "Epoch: 028/050 | Batch 550/891 | Cost: 0.0855\n", + "Epoch: 028/050 | Batch 600/891 | Cost: 0.0979\n", + "Epoch: 028/050 | Batch 650/891 | Cost: 0.1288\n", + "Epoch: 028/050 | Batch 700/891 | Cost: 0.1454\n", + "Epoch: 028/050 | Batch 750/891 | Cost: 0.0631\n", + "Epoch: 028/050 | Batch 800/891 | Cost: 0.1604\n", + "Epoch: 028/050 | Batch 850/891 | Cost: 0.1495\n", + "training accuracy: 97.54%\n", + "valid accuracy: 91.87%\n", + "Time elapsed: 17.68 min\n", + "Epoch: 029/050 | Batch 000/891 | Cost: 0.0644\n", + "Epoch: 029/050 | Batch 050/891 | Cost: 0.0699\n", + "Epoch: 029/050 | Batch 100/891 | Cost: 0.2319\n", + "Epoch: 029/050 | Batch 150/891 | Cost: 0.1196\n", + "Epoch: 029/050 | Batch 200/891 | Cost: 0.0950\n", + "Epoch: 029/050 | Batch 250/891 | Cost: 0.1323\n", + "Epoch: 029/050 | Batch 300/891 | Cost: 0.2933\n", + "Epoch: 029/050 | Batch 350/891 | Cost: 0.1934\n", + "Epoch: 029/050 | Batch 400/891 | Cost: 0.0852\n", + "Epoch: 029/050 | Batch 450/891 | Cost: 0.1402\n", + "Epoch: 029/050 | Batch 500/891 | Cost: 0.2230\n", + "Epoch: 029/050 | Batch 550/891 | Cost: 0.0998\n", + "Epoch: 029/050 | Batch 600/891 | Cost: 0.1782\n", + "Epoch: 029/050 | Batch 650/891 | Cost: 0.3283\n", + "Epoch: 029/050 | Batch 700/891 | Cost: 0.2203\n", + "Epoch: 029/050 | Batch 750/891 | Cost: 0.1579\n", + "Epoch: 029/050 | 
Batch 800/891 | Cost: 0.1457\n", + "Epoch: 029/050 | Batch 850/891 | Cost: 0.2025\n", + "training accuracy: 97.45%\n", + "valid accuracy: 91.53%\n", + "Time elapsed: 18.74 min\n", + "Epoch: 030/050 | Batch 000/891 | Cost: 0.0462\n", + "Epoch: 030/050 | Batch 050/891 | Cost: 0.1564\n", + "Epoch: 030/050 | Batch 100/891 | Cost: 0.0746\n", + "Epoch: 030/050 | Batch 150/891 | Cost: 0.1384\n", + "Epoch: 030/050 | Batch 200/891 | Cost: 0.2740\n", + "Epoch: 030/050 | Batch 250/891 | Cost: 0.3271\n", + "Epoch: 030/050 | Batch 300/891 | Cost: 0.1764\n", + "Epoch: 030/050 | Batch 350/891 | Cost: 0.1777\n", + "Epoch: 030/050 | Batch 400/891 | Cost: 0.0841\n", + "Epoch: 030/050 | Batch 450/891 | Cost: 0.1597\n", + "Epoch: 030/050 | Batch 500/891 | Cost: 0.1223\n", + "Epoch: 030/050 | Batch 550/891 | Cost: 0.1083\n", + "Epoch: 030/050 | Batch 600/891 | Cost: 0.1478\n", + "Epoch: 030/050 | Batch 650/891 | Cost: 0.2959\n", + "Epoch: 030/050 | Batch 700/891 | Cost: 0.1887\n", + "Epoch: 030/050 | Batch 750/891 | Cost: 0.2498\n", + "Epoch: 030/050 | Batch 800/891 | Cost: 0.1300\n", + "Epoch: 030/050 | Batch 850/891 | Cost: 0.1651\n", + "training accuracy: 97.41%\n", + "valid accuracy: 91.53%\n", + "Time elapsed: 19.80 min\n", + "Epoch: 031/050 | Batch 000/891 | Cost: 0.2204\n", + "Epoch: 031/050 | Batch 050/891 | Cost: 0.0253\n", + "Epoch: 031/050 | Batch 100/891 | Cost: 0.2895\n", + "Epoch: 031/050 | Batch 150/891 | Cost: 0.1715\n", + "Epoch: 031/050 | Batch 200/891 | Cost: 0.1887\n", + "Epoch: 031/050 | Batch 250/891 | Cost: 0.2059\n", + "Epoch: 031/050 | Batch 300/891 | Cost: 0.0932\n", + "Epoch: 031/050 | Batch 350/891 | Cost: 0.1699\n", + "Epoch: 031/050 | Batch 400/891 | Cost: 0.0939\n", + "Epoch: 031/050 | Batch 450/891 | Cost: 0.1887\n", + "Epoch: 031/050 | Batch 500/891 | Cost: 0.1506\n", + "Epoch: 031/050 | Batch 550/891 | Cost: 0.0940\n", + "Epoch: 031/050 | Batch 600/891 | Cost: 0.0522\n", + "Epoch: 031/050 | Batch 650/891 | Cost: 0.0805\n", + "Epoch: 031/050 | Batch 700/891 | Cost: 0.1576\n", + "Epoch: 031/050 | Batch 750/891 | Cost: 0.0976\n", + "Epoch: 031/050 | Batch 800/891 | Cost: 0.2967\n", + "Epoch: 031/050 | Batch 850/891 | Cost: 0.1926\n", + "training accuracy: 97.74%\n", + "valid accuracy: 91.80%\n", + "Time elapsed: 20.79 min\n", + "Epoch: 032/050 | Batch 000/891 | Cost: 0.2118\n", + "Epoch: 032/050 | Batch 050/891 | Cost: 0.1500\n", + "Epoch: 032/050 | Batch 100/891 | Cost: 0.0699\n", + "Epoch: 032/050 | Batch 150/891 | Cost: 0.1424\n", + "Epoch: 032/050 | Batch 200/891 | Cost: 0.2768\n", + "Epoch: 032/050 | Batch 250/891 | Cost: 0.0965\n", + "Epoch: 032/050 | Batch 300/891 | Cost: 0.0836\n", + "Epoch: 032/050 | Batch 350/891 | Cost: 0.1566\n", + "Epoch: 032/050 | Batch 400/891 | Cost: 0.1140\n", + "Epoch: 032/050 | Batch 450/891 | Cost: 0.1286\n", + "Epoch: 032/050 | Batch 500/891 | Cost: 0.1687\n", + "Epoch: 032/050 | Batch 550/891 | Cost: 0.0647\n", + "Epoch: 032/050 | Batch 600/891 | Cost: 0.0885\n", + "Epoch: 032/050 | Batch 650/891 | Cost: 0.0491\n", + "Epoch: 032/050 | Batch 700/891 | Cost: 0.0612\n", + "Epoch: 032/050 | Batch 750/891 | Cost: 0.0645\n", + "Epoch: 032/050 | Batch 800/891 | Cost: 0.2246\n", + "Epoch: 032/050 | Batch 850/891 | Cost: 0.0900\n", + "training accuracy: 97.75%\n", + "valid accuracy: 91.77%\n", + "Time elapsed: 21.75 min\n", + "Epoch: 033/050 | Batch 000/891 | Cost: 0.1070\n", + "Epoch: 033/050 | Batch 050/891 | Cost: 0.1982\n", + "Epoch: 033/050 | Batch 100/891 | Cost: 0.1159\n", + "Epoch: 033/050 | Batch 150/891 | Cost: 0.1398\n", + "Epoch: 033/050 
| Batch 200/891 | Cost: 0.0937\n", + "Epoch: 033/050 | Batch 250/891 | Cost: 0.1015\n", + "Epoch: 033/050 | Batch 300/891 | Cost: 0.0945\n", + "Epoch: 033/050 | Batch 350/891 | Cost: 0.0534\n", + "Epoch: 033/050 | Batch 400/891 | Cost: 0.1476\n", + "Epoch: 033/050 | Batch 450/891 | Cost: 0.0937\n", + "Epoch: 033/050 | Batch 500/891 | Cost: 0.2442\n", + "Epoch: 033/050 | Batch 550/891 | Cost: 0.0817\n", + "Epoch: 033/050 | Batch 600/891 | Cost: 0.2181\n", + "Epoch: 033/050 | Batch 650/891 | Cost: 0.2121\n", + "Epoch: 033/050 | Batch 700/891 | Cost: 0.1767\n", + "Epoch: 033/050 | Batch 750/891 | Cost: 0.2248\n", + "Epoch: 033/050 | Batch 800/891 | Cost: 0.1277\n", + "Epoch: 033/050 | Batch 850/891 | Cost: 0.1004\n", + "training accuracy: 97.88%\n", + "valid accuracy: 91.63%\n", + "Time elapsed: 22.69 min\n", + "Epoch: 034/050 | Batch 000/891 | Cost: 0.1261\n", + "Epoch: 034/050 | Batch 050/891 | Cost: 0.1267\n", + "Epoch: 034/050 | Batch 100/891 | Cost: 0.1777\n", + "Epoch: 034/050 | Batch 150/891 | Cost: 0.2866\n", + "Epoch: 034/050 | Batch 200/891 | Cost: 0.0845\n", + "Epoch: 034/050 | Batch 250/891 | Cost: 0.2171\n", + "Epoch: 034/050 | Batch 300/891 | Cost: 0.1906\n", + "Epoch: 034/050 | Batch 350/891 | Cost: 0.1531\n", + "Epoch: 034/050 | Batch 400/891 | Cost: 0.0928\n", + "Epoch: 034/050 | Batch 450/891 | Cost: 0.1674\n", + "Epoch: 034/050 | Batch 500/891 | Cost: 0.2959\n", + "Epoch: 034/050 | Batch 550/891 | Cost: 0.1654\n", + "Epoch: 034/050 | Batch 600/891 | Cost: 0.2238\n", + "Epoch: 034/050 | Batch 650/891 | Cost: 0.1358\n", + "Epoch: 034/050 | Batch 700/891 | Cost: 0.0593\n", + "Epoch: 034/050 | Batch 750/891 | Cost: 0.2061\n", + "Epoch: 034/050 | Batch 800/891 | Cost: 0.0418\n", + "Epoch: 034/050 | Batch 850/891 | Cost: 0.1814\n", + "training accuracy: 97.77%\n", + "valid accuracy: 91.53%\n", + "Time elapsed: 23.67 min\n", + "Epoch: 035/050 | Batch 000/891 | Cost: 0.2832\n", + "Epoch: 035/050 | Batch 050/891 | Cost: 0.0631\n", + "Epoch: 035/050 | Batch 100/891 | Cost: 0.1005\n", + "Epoch: 035/050 | Batch 150/891 | Cost: 0.1677\n", + "Epoch: 035/050 | Batch 200/891 | Cost: 0.0663\n", + "Epoch: 035/050 | Batch 250/891 | Cost: 0.1370\n", + "Epoch: 035/050 | Batch 300/891 | Cost: 0.1260\n", + "Epoch: 035/050 | Batch 350/891 | Cost: 0.1642\n", + "Epoch: 035/050 | Batch 400/891 | Cost: 0.1703\n", + "Epoch: 035/050 | Batch 450/891 | Cost: 0.1147\n", + "Epoch: 035/050 | Batch 500/891 | Cost: 0.1205\n", + "Epoch: 035/050 | Batch 550/891 | Cost: 0.1352\n", + "Epoch: 035/050 | Batch 600/891 | Cost: 0.1017\n", + "Epoch: 035/050 | Batch 650/891 | Cost: 0.2116\n", + "Epoch: 035/050 | Batch 700/891 | Cost: 0.1301\n", + "Epoch: 035/050 | Batch 750/891 | Cost: 0.1565\n", + "Epoch: 035/050 | Batch 800/891 | Cost: 0.0610\n", + "Epoch: 035/050 | Batch 850/891 | Cost: 0.1000\n", + "training accuracy: 98.02%\n", + "valid accuracy: 91.92%\n", + "Time elapsed: 24.75 min\n", + "Epoch: 036/050 | Batch 000/891 | Cost: 0.2945\n", + "Epoch: 036/050 | Batch 050/891 | Cost: 0.0929\n", + "Epoch: 036/050 | Batch 100/891 | Cost: 0.1919\n", + "Epoch: 036/050 | Batch 150/891 | Cost: 0.1328\n", + "Epoch: 036/050 | Batch 200/891 | Cost: 0.0948\n", + "Epoch: 036/050 | Batch 250/891 | Cost: 0.0330\n", + "Epoch: 036/050 | Batch 300/891 | Cost: 0.1418\n", + "Epoch: 036/050 | Batch 350/891 | Cost: 0.3359\n", + "Epoch: 036/050 | Batch 400/891 | Cost: 0.3079\n", + "Epoch: 036/050 | Batch 450/891 | Cost: 0.1771\n", + "Epoch: 036/050 | Batch 500/891 | Cost: 0.0698\n", + "Epoch: 036/050 | Batch 550/891 | Cost: 0.1285\n", + 
"Epoch: 036/050 | Batch 600/891 | Cost: 0.0174\n", + "Epoch: 036/050 | Batch 650/891 | Cost: 0.1377\n", + "Epoch: 036/050 | Batch 700/891 | Cost: 0.1203\n", + "Epoch: 036/050 | Batch 750/891 | Cost: 0.0861\n", + "Epoch: 036/050 | Batch 800/891 | Cost: 0.0767\n", + "Epoch: 036/050 | Batch 850/891 | Cost: 0.1800\n", + "training accuracy: 97.97%\n", + "valid accuracy: 91.88%\n", + "Time elapsed: 25.82 min\n", + "Epoch: 037/050 | Batch 000/891 | Cost: 0.3566\n", + "Epoch: 037/050 | Batch 050/891 | Cost: 0.1634\n", + "Epoch: 037/050 | Batch 100/891 | Cost: 0.1186\n", + "Epoch: 037/050 | Batch 150/891 | Cost: 0.1233\n", + "Epoch: 037/050 | Batch 200/891 | Cost: 0.1115\n", + "Epoch: 037/050 | Batch 250/891 | Cost: 0.1204\n", + "Epoch: 037/050 | Batch 300/891 | Cost: 0.0447\n", + "Epoch: 037/050 | Batch 350/891 | Cost: 0.1045\n", + "Epoch: 037/050 | Batch 400/891 | Cost: 0.1046\n", + "Epoch: 037/050 | Batch 450/891 | Cost: 0.0250\n", + "Epoch: 037/050 | Batch 500/891 | Cost: 0.0988\n", + "Epoch: 037/050 | Batch 550/891 | Cost: 0.1314\n", + "Epoch: 037/050 | Batch 600/891 | Cost: 0.1060\n", + "Epoch: 037/050 | Batch 650/891 | Cost: 0.1120\n", + "Epoch: 037/050 | Batch 700/891 | Cost: 0.1844\n", + "Epoch: 037/050 | Batch 750/891 | Cost: 0.0897\n", + "Epoch: 037/050 | Batch 800/891 | Cost: 0.2487\n", + "Epoch: 037/050 | Batch 850/891 | Cost: 0.1493\n", + "training accuracy: 97.98%\n", + "valid accuracy: 91.48%\n", + "Time elapsed: 26.89 min\n", + "Epoch: 038/050 | Batch 000/891 | Cost: 0.1361\n", + "Epoch: 038/050 | Batch 050/891 | Cost: 0.1114\n", + "Epoch: 038/050 | Batch 100/891 | Cost: 0.1495\n", + "Epoch: 038/050 | Batch 150/891 | Cost: 0.0973\n", + "Epoch: 038/050 | Batch 200/891 | Cost: 0.1874\n", + "Epoch: 038/050 | Batch 250/891 | Cost: 0.1043\n", + "Epoch: 038/050 | Batch 300/891 | Cost: 0.1514\n", + "Epoch: 038/050 | Batch 350/891 | Cost: 0.2377\n", + "Epoch: 038/050 | Batch 400/891 | Cost: 0.2675\n", + "Epoch: 038/050 | Batch 450/891 | Cost: 0.0705\n", + "Epoch: 038/050 | Batch 500/891 | Cost: 0.1921\n", + "Epoch: 038/050 | Batch 550/891 | Cost: 0.0772\n", + "Epoch: 038/050 | Batch 600/891 | Cost: 0.2542\n", + "Epoch: 038/050 | Batch 650/891 | Cost: 0.0602\n", + "Epoch: 038/050 | Batch 700/891 | Cost: 0.1468\n", + "Epoch: 038/050 | Batch 750/891 | Cost: 0.0620\n", + "Epoch: 038/050 | Batch 800/891 | Cost: 0.1213\n", + "Epoch: 038/050 | Batch 850/891 | Cost: 0.1046\n", + "training accuracy: 98.07%\n", + "valid accuracy: 91.80%\n", + "Time elapsed: 27.93 min\n", + "Epoch: 039/050 | Batch 000/891 | Cost: 0.1133\n", + "Epoch: 039/050 | Batch 050/891 | Cost: 0.1479\n", + "Epoch: 039/050 | Batch 100/891 | Cost: 0.1279\n", + "Epoch: 039/050 | Batch 150/891 | Cost: 0.1508\n", + "Epoch: 039/050 | Batch 200/891 | Cost: 0.1695\n", + "Epoch: 039/050 | Batch 250/891 | Cost: 0.1512\n", + "Epoch: 039/050 | Batch 300/891 | Cost: 0.1059\n", + "Epoch: 039/050 | Batch 350/891 | Cost: 0.0721\n", + "Epoch: 039/050 | Batch 400/891 | Cost: 0.0856\n", + "Epoch: 039/050 | Batch 450/891 | Cost: 0.1215\n", + "Epoch: 039/050 | Batch 500/891 | Cost: 0.0628\n", + "Epoch: 039/050 | Batch 550/891 | Cost: 0.1136\n", + "Epoch: 039/050 | Batch 600/891 | Cost: 0.0866\n", + "Epoch: 039/050 | Batch 650/891 | Cost: 0.0740\n", + "Epoch: 039/050 | Batch 700/891 | Cost: 0.0922\n", + "Epoch: 039/050 | Batch 750/891 | Cost: 0.0684\n", + "Epoch: 039/050 | Batch 800/891 | Cost: 0.1036\n", + "Epoch: 039/050 | Batch 850/891 | Cost: 0.3993\n", + "training accuracy: 98.14%\n", + "valid accuracy: 91.50%\n", + "Time elapsed: 28.98 min\n", 
+ "Epoch: 040/050 | Batch 000/891 | Cost: 0.1712\n", + "Epoch: 040/050 | Batch 050/891 | Cost: 0.1368\n", + "Epoch: 040/050 | Batch 100/891 | Cost: 0.2130\n", + "Epoch: 040/050 | Batch 150/891 | Cost: 0.2074\n", + "Epoch: 040/050 | Batch 200/891 | Cost: 0.1886\n", + "Epoch: 040/050 | Batch 250/891 | Cost: 0.0763\n", + "Epoch: 040/050 | Batch 300/891 | Cost: 0.1250\n", + "Epoch: 040/050 | Batch 350/891 | Cost: 0.0659\n", + "Epoch: 040/050 | Batch 400/891 | Cost: 0.1597\n", + "Epoch: 040/050 | Batch 450/891 | Cost: 0.0973\n", + "Epoch: 040/050 | Batch 500/891 | Cost: 0.1974\n", + "Epoch: 040/050 | Batch 550/891 | Cost: 0.0470\n", + "Epoch: 040/050 | Batch 600/891 | Cost: 0.0981\n", + "Epoch: 040/050 | Batch 650/891 | Cost: 0.2160\n", + "Epoch: 040/050 | Batch 700/891 | Cost: 0.0991\n", + "Epoch: 040/050 | Batch 750/891 | Cost: 0.1553\n", + "Epoch: 040/050 | Batch 800/891 | Cost: 0.2289\n", + "Epoch: 040/050 | Batch 850/891 | Cost: 0.1656\n", + "training accuracy: 98.16%\n", + "valid accuracy: 91.67%\n", + "Time elapsed: 30.01 min\n", + "Epoch: 041/050 | Batch 000/891 | Cost: 0.1532\n", + "Epoch: 041/050 | Batch 050/891 | Cost: 0.1516\n", + "Epoch: 041/050 | Batch 100/891 | Cost: 0.1026\n", + "Epoch: 041/050 | Batch 150/891 | Cost: 0.2094\n", + "Epoch: 041/050 | Batch 200/891 | Cost: 0.0773\n", + "Epoch: 041/050 | Batch 250/891 | Cost: 0.0909\n", + "Epoch: 041/050 | Batch 300/891 | Cost: 0.1079\n", + "Epoch: 041/050 | Batch 350/891 | Cost: 0.2061\n", + "Epoch: 041/050 | Batch 400/891 | Cost: 0.0633\n", + "Epoch: 041/050 | Batch 450/891 | Cost: 0.1377\n", + "Epoch: 041/050 | Batch 500/891 | Cost: 0.2176\n", + "Epoch: 041/050 | Batch 550/891 | Cost: 0.1144\n", + "Epoch: 041/050 | Batch 600/891 | Cost: 0.1907\n", + "Epoch: 041/050 | Batch 650/891 | Cost: 0.1184\n", + "Epoch: 041/050 | Batch 700/891 | Cost: 0.0938\n", + "Epoch: 041/050 | Batch 750/891 | Cost: 0.0866\n", + "Epoch: 041/050 | Batch 800/891 | Cost: 0.1442\n", + "Epoch: 041/050 | Batch 850/891 | Cost: 0.0893\n", + "training accuracy: 98.25%\n", + "valid accuracy: 91.70%\n", + "Time elapsed: 31.05 min\n", + "Epoch: 042/050 | Batch 000/891 | Cost: 0.1878\n", + "Epoch: 042/050 | Batch 050/891 | Cost: 0.1001\n", + "Epoch: 042/050 | Batch 100/891 | Cost: 0.0742\n", + "Epoch: 042/050 | Batch 150/891 | Cost: 0.1685\n", + "Epoch: 042/050 | Batch 200/891 | Cost: 0.0812\n", + "Epoch: 042/050 | Batch 250/891 | Cost: 0.1662\n", + "Epoch: 042/050 | Batch 300/891 | Cost: 0.0969\n", + "Epoch: 042/050 | Batch 350/891 | Cost: 0.1765\n", + "Epoch: 042/050 | Batch 400/891 | Cost: 0.0659\n", + "Epoch: 042/050 | Batch 450/891 | Cost: 0.1227\n", + "Epoch: 042/050 | Batch 500/891 | Cost: 0.0946\n", + "Epoch: 042/050 | Batch 550/891 | Cost: 0.1164\n", + "Epoch: 042/050 | Batch 600/891 | Cost: 0.1121\n", + "Epoch: 042/050 | Batch 650/891 | Cost: 0.1068\n", + "Epoch: 042/050 | Batch 700/891 | Cost: 0.0964\n", + "Epoch: 042/050 | Batch 750/891 | Cost: 0.1052\n", + "Epoch: 042/050 | Batch 800/891 | Cost: 0.0914\n", + "Epoch: 042/050 | Batch 850/891 | Cost: 0.1908\n", + "training accuracy: 98.24%\n", + "valid accuracy: 91.52%\n", + "Time elapsed: 32.08 min\n", + "Epoch: 043/050 | Batch 000/891 | Cost: 0.1148\n", + "Epoch: 043/050 | Batch 050/891 | Cost: 0.0874\n", + "Epoch: 043/050 | Batch 100/891 | Cost: 0.1539\n", + "Epoch: 043/050 | Batch 150/891 | Cost: 0.1270\n", + "Epoch: 043/050 | Batch 200/891 | Cost: 0.0444\n", + "Epoch: 043/050 | Batch 250/891 | Cost: 0.0705\n", + "Epoch: 043/050 | Batch 300/891 | Cost: 0.1335\n", + "Epoch: 043/050 | Batch 350/891 | 
Cost: 0.2058\n", + "Epoch: 043/050 | Batch 400/891 | Cost: 0.1839\n", + "Epoch: 043/050 | Batch 450/891 | Cost: 0.1798\n", + "Epoch: 043/050 | Batch 500/891 | Cost: 0.1855\n", + "Epoch: 043/050 | Batch 550/891 | Cost: 0.1608\n", + "Epoch: 043/050 | Batch 600/891 | Cost: 0.1785\n", + "Epoch: 043/050 | Batch 650/891 | Cost: 0.1823\n", + "Epoch: 043/050 | Batch 700/891 | Cost: 0.1660\n", + "Epoch: 043/050 | Batch 750/891 | Cost: 0.2193\n", + "Epoch: 043/050 | Batch 800/891 | Cost: 0.1133\n", + "Epoch: 043/050 | Batch 850/891 | Cost: 0.0708\n", + "training accuracy: 98.25%\n", + "valid accuracy: 91.60%\n", + "Time elapsed: 33.12 min\n", + "Epoch: 044/050 | Batch 000/891 | Cost: 0.1061\n", + "Epoch: 044/050 | Batch 050/891 | Cost: 0.1410\n", + "Epoch: 044/050 | Batch 100/891 | Cost: 0.0963\n", + "Epoch: 044/050 | Batch 150/891 | Cost: 0.0455\n", + "Epoch: 044/050 | Batch 200/891 | Cost: 0.1148\n", + "Epoch: 044/050 | Batch 250/891 | Cost: 0.0956\n", + "Epoch: 044/050 | Batch 300/891 | Cost: 0.1357\n", + "Epoch: 044/050 | Batch 350/891 | Cost: 0.0914\n", + "Epoch: 044/050 | Batch 400/891 | Cost: 0.1779\n", + "Epoch: 044/050 | Batch 450/891 | Cost: 0.0951\n", + "Epoch: 044/050 | Batch 500/891 | Cost: 0.0805\n", + "Epoch: 044/050 | Batch 550/891 | Cost: 0.0946\n", + "Epoch: 044/050 | Batch 600/891 | Cost: 0.2519\n", + "Epoch: 044/050 | Batch 650/891 | Cost: 0.0587\n", + "Epoch: 044/050 | Batch 700/891 | Cost: 0.1026\n", + "Epoch: 044/050 | Batch 750/891 | Cost: 0.0970\n", + "Epoch: 044/050 | Batch 800/891 | Cost: 0.1420\n", + "Epoch: 044/050 | Batch 850/891 | Cost: 0.0799\n", + "training accuracy: 98.27%\n", + "valid accuracy: 91.50%\n", + "Time elapsed: 34.09 min\n", + "Epoch: 045/050 | Batch 000/891 | Cost: 0.1535\n", + "Epoch: 045/050 | Batch 050/891 | Cost: 0.1314\n", + "Epoch: 045/050 | Batch 100/891 | Cost: 0.0673\n", + "Epoch: 045/050 | Batch 150/891 | Cost: 0.1049\n", + "Epoch: 045/050 | Batch 200/891 | Cost: 0.0908\n", + "Epoch: 045/050 | Batch 250/891 | Cost: 0.2232\n", + "Epoch: 045/050 | Batch 300/891 | Cost: 0.0698\n", + "Epoch: 045/050 | Batch 350/891 | Cost: 0.0505\n", + "Epoch: 045/050 | Batch 400/891 | Cost: 0.0682\n", + "Epoch: 045/050 | Batch 450/891 | Cost: 0.1018\n", + "Epoch: 045/050 | Batch 500/891 | Cost: 0.0461\n", + "Epoch: 045/050 | Batch 550/891 | Cost: 0.1451\n", + "Epoch: 045/050 | Batch 600/891 | Cost: 0.0264\n", + "Epoch: 045/050 | Batch 650/891 | Cost: 0.0608\n", + "Epoch: 045/050 | Batch 700/891 | Cost: 0.1043\n", + "Epoch: 045/050 | Batch 750/891 | Cost: 0.0882\n", + "Epoch: 045/050 | Batch 800/891 | Cost: 0.1163\n", + "Epoch: 045/050 | Batch 850/891 | Cost: 0.2396\n", + "training accuracy: 98.29%\n", + "valid accuracy: 91.40%\n", + "Time elapsed: 35.03 min\n", + "Epoch: 046/050 | Batch 000/891 | Cost: 0.0788\n", + "Epoch: 046/050 | Batch 050/891 | Cost: 0.0304\n", + "Epoch: 046/050 | Batch 100/891 | Cost: 0.0826\n", + "Epoch: 046/050 | Batch 150/891 | Cost: 0.1860\n", + "Epoch: 046/050 | Batch 200/891 | Cost: 0.1872\n", + "Epoch: 046/050 | Batch 250/891 | Cost: 0.0610\n", + "Epoch: 046/050 | Batch 300/891 | Cost: 0.1037\n", + "Epoch: 046/050 | Batch 350/891 | Cost: 0.1565\n", + "Epoch: 046/050 | Batch 400/891 | Cost: 0.1976\n", + "Epoch: 046/050 | Batch 450/891 | Cost: 0.1081\n", + "Epoch: 046/050 | Batch 500/891 | Cost: 0.1374\n", + "Epoch: 046/050 | Batch 550/891 | Cost: 0.0744\n", + "Epoch: 046/050 | Batch 600/891 | Cost: 0.0795\n", + "Epoch: 046/050 | Batch 650/891 | Cost: 0.1045\n", + "Epoch: 046/050 | Batch 700/891 | Cost: 0.2454\n", + "Epoch: 046/050 | 
Batch 750/891 | Cost: 0.1897\n", + "Epoch: 046/050 | Batch 800/891 | Cost: 0.0899\n", + "Epoch: 046/050 | Batch 850/891 | Cost: 0.1644\n", + "training accuracy: 98.52%\n", + "valid accuracy: 91.80%\n", + "Time elapsed: 35.97 min\n", + "Epoch: 047/050 | Batch 000/891 | Cost: 0.0844\n", + "Epoch: 047/050 | Batch 050/891 | Cost: 0.1276\n", + "Epoch: 047/050 | Batch 100/891 | Cost: 0.1050\n", + "Epoch: 047/050 | Batch 150/891 | Cost: 0.0994\n", + "Epoch: 047/050 | Batch 200/891 | Cost: 0.0310\n", + "Epoch: 047/050 | Batch 250/891 | Cost: 0.1233\n", + "Epoch: 047/050 | Batch 300/891 | Cost: 0.1956\n", + "Epoch: 047/050 | Batch 350/891 | Cost: 0.1355\n", + "Epoch: 047/050 | Batch 400/891 | Cost: 0.0901\n", + "Epoch: 047/050 | Batch 450/891 | Cost: 0.1141\n", + "Epoch: 047/050 | Batch 500/891 | Cost: 0.1127\n", + "Epoch: 047/050 | Batch 550/891 | Cost: 0.1333\n", + "Epoch: 047/050 | Batch 600/891 | Cost: 0.0607\n", + "Epoch: 047/050 | Batch 650/891 | Cost: 0.0458\n", + "Epoch: 047/050 | Batch 700/891 | Cost: 0.0623\n", + "Epoch: 047/050 | Batch 750/891 | Cost: 0.1557\n", + "Epoch: 047/050 | Batch 800/891 | Cost: 0.0998\n", + "Epoch: 047/050 | Batch 850/891 | Cost: 0.1906\n", + "training accuracy: 98.39%\n", + "valid accuracy: 91.62%\n", + "Time elapsed: 36.90 min\n", + "Epoch: 048/050 | Batch 000/891 | Cost: 0.0498\n", + "Epoch: 048/050 | Batch 050/891 | Cost: 0.1280\n", + "Epoch: 048/050 | Batch 100/891 | Cost: 0.3360\n", + "Epoch: 048/050 | Batch 150/891 | Cost: 0.1495\n", + "Epoch: 048/050 | Batch 200/891 | Cost: 0.1255\n", + "Epoch: 048/050 | Batch 250/891 | Cost: 0.0538\n", + "Epoch: 048/050 | Batch 300/891 | Cost: 0.1525\n", + "Epoch: 048/050 | Batch 350/891 | Cost: 0.0628\n", + "Epoch: 048/050 | Batch 400/891 | Cost: 0.0923\n", + "Epoch: 048/050 | Batch 450/891 | Cost: 0.2230\n", + "Epoch: 048/050 | Batch 500/891 | Cost: 0.3083\n", + "Epoch: 048/050 | Batch 550/891 | Cost: 0.0439\n", + "Epoch: 048/050 | Batch 600/891 | Cost: 0.0468\n", + "Epoch: 048/050 | Batch 650/891 | Cost: 0.0583\n", + "Epoch: 048/050 | Batch 700/891 | Cost: 0.1199\n", + "Epoch: 048/050 | Batch 750/891 | Cost: 0.0736\n", + "Epoch: 048/050 | Batch 800/891 | Cost: 0.1704\n", + "Epoch: 048/050 | Batch 850/891 | Cost: 0.1210\n", + "training accuracy: 98.62%\n", + "valid accuracy: 91.67%\n", + "Time elapsed: 37.94 min\n", + "Epoch: 049/050 | Batch 000/891 | Cost: 0.0950\n", + "Epoch: 049/050 | Batch 050/891 | Cost: 0.0561\n", + "Epoch: 049/050 | Batch 100/891 | Cost: 0.0741\n", + "Epoch: 049/050 | Batch 150/891 | Cost: 0.1510\n", + "Epoch: 049/050 | Batch 200/891 | Cost: 0.0725\n", + "Epoch: 049/050 | Batch 250/891 | Cost: 0.1095\n", + "Epoch: 049/050 | Batch 300/891 | Cost: 0.0607\n", + "Epoch: 049/050 | Batch 350/891 | Cost: 0.1911\n", + "Epoch: 049/050 | Batch 400/891 | Cost: 0.0869\n", + "Epoch: 049/050 | Batch 450/891 | Cost: 0.0695\n", + "Epoch: 049/050 | Batch 500/891 | Cost: 0.1631\n", + "Epoch: 049/050 | Batch 550/891 | Cost: 0.2730\n", + "Epoch: 049/050 | Batch 600/891 | Cost: 0.0997\n", + "Epoch: 049/050 | Batch 650/891 | Cost: 0.0588\n", + "Epoch: 049/050 | Batch 700/891 | Cost: 0.0969\n", + "Epoch: 049/050 | Batch 750/891 | Cost: 0.1929\n", + "Epoch: 049/050 | Batch 800/891 | Cost: 0.0639\n", + "Epoch: 049/050 | Batch 850/891 | Cost: 0.1441\n", + "training accuracy: 98.67%\n", + "valid accuracy: 91.80%\n", + "Time elapsed: 38.98 min\n", + "Epoch: 050/050 | Batch 000/891 | Cost: 0.0646\n", + "Epoch: 050/050 | Batch 050/891 | Cost: 0.1085\n", + "Epoch: 050/050 | Batch 100/891 | Cost: 0.1356\n", + "Epoch: 050/050 
| Batch 150/891 | Cost: 0.0649\n", + "Epoch: 050/050 | Batch 200/891 | Cost: 0.1520\n", + "Epoch: 050/050 | Batch 250/891 | Cost: 0.0987\n", + "Epoch: 050/050 | Batch 300/891 | Cost: 0.1930\n", + "Epoch: 050/050 | Batch 350/891 | Cost: 0.2051\n", + "Epoch: 050/050 | Batch 400/891 | Cost: 0.1187\n", + "Epoch: 050/050 | Batch 450/891 | Cost: 0.0401\n", + "Epoch: 050/050 | Batch 500/891 | Cost: 0.0716\n", + "Epoch: 050/050 | Batch 550/891 | Cost: 0.1372\n", + "Epoch: 050/050 | Batch 600/891 | Cost: 0.1621\n", + "Epoch: 050/050 | Batch 650/891 | Cost: 0.1026\n", + "Epoch: 050/050 | Batch 700/891 | Cost: 0.1087\n", + "Epoch: 050/050 | Batch 750/891 | Cost: 0.1647\n", + "Epoch: 050/050 | Batch 800/891 | Cost: 0.1104\n", + "Epoch: 050/050 | Batch 850/891 | Cost: 0.0536\n", + "training accuracy: 98.72%\n", + "valid accuracy: 91.85%\n", + "Time elapsed: 40.01 min\n", + "Total Training Time: 40.01 min\n", + "Test accuracy: 91.26%\n" + ] + } + ], + "source": [ + "start_time = time.time()\n", + "\n", + "for epoch in range(NUM_EPOCHS):\n", + " model.train()\n", + " for batch_idx, batch_data in enumerate(train_loader):\n", + " \n", + " text, text_lengths = batch_data.content\n", + " \n", + " ### FORWARD AND BACK PROP\n", + " logits = model(text, text_lengths).squeeze(1)\n", + " cost = F.cross_entropy(logits, batch_data.classlabel.long())\n", + " optimizer.zero_grad()\n", + " \n", + " cost.backward()\n", + " \n", + " ### UPDATE MODEL PARAMETERS\n", + " optimizer.step()\n", + " \n", + " ### LOGGING\n", + " if not batch_idx % 50:\n", + " print (f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} | '\n", + " f'Batch {batch_idx:03d}/{len(train_loader):03d} | '\n", + " f'Cost: {cost:.4f}')\n", + "\n", + " with torch.set_grad_enabled(False):\n", + " print(f'training accuracy: '\n", + " f'{compute_accuracy(model, train_loader, DEVICE):.2f}%'\n", + " f'\\nvalid accuracy: '\n", + " f'{compute_accuracy(model, valid_loader, DEVICE):.2f}%')\n", + " \n", + " print(f'Time elapsed: {(time.time() - start_time)/60:.2f} min')\n", + " \n", + "print(f'Total Training Time: {(time.time() - start_time)/60:.2f} min')\n", + "print(f'Test accuracy: {compute_accuracy(model, test_loader, DEVICE):.2f}%')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Evaluation" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Evaluating on some new text that has been collected from recent news articles and is not part of the training or test sets." 
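The training log above shows training accuracy climbing to 98.72% while validation accuracy plateaus around 91.5-91.9% from roughly epoch 30 onward, so the later epochs mostly overfit. Before moving on to inference, here is a minimal sketch (not a cell from the original notebook) of keeping the checkpoint with the best validation accuracy instead of the final epoch's weights; it assumes the notebook's `model`, `optimizer`, `train_loader`, `valid_loader`, `compute_accuracy`, `DEVICE`, and `NUM_EPOCHS` are in scope, and the checkpoint filename is made up:

```python
# Hypothetical variant of the training loop above: after each epoch, save the
# weights only if the validation accuracy improved (filename is illustrative).
best_valid_acc = 0.

for epoch in range(NUM_EPOCHS):
    model.train()
    for batch_data in train_loader:
        text, text_lengths = batch_data.content
        logits = model(text, text_lengths)
        cost = F.cross_entropy(logits, batch_data.classlabel.long())
        optimizer.zero_grad()
        cost.backward()
        optimizer.step()

    valid_acc = compute_accuracy(model, valid_loader, DEVICE)  # calls model.eval() internally
    if valid_acc > best_valid_acc:
        best_valid_acc = valid_acc
        torch.save(model.state_dict(), 'best_agnews_model.pt')

# Restore the best checkpoint before the final test-set evaluation:
model.load_state_dict(torch.load('best_agnews_model.pt'))
```

With the plateau in the log above, this would select a checkpoint from around epoch 35 (valid accuracy 91.92%) rather than the epoch-50 weights.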
+ ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "jt55pscgFdKZ" + }, + "outputs": [], + "source": [ + "import spacy\n", + "nlp = spacy.load('en')\n", + "\n", + "\n", + "map_dictionary = {\n", + " 0: \"World\",\n", + " 1: \"Sports\",\n", + " 2: \"Business\",\n", + " 3:\"Sci/Tech\",\n", + "}\n", + "\n", + "\n", + "def predict_class(model, sentence, min_len=4):\n", + " # Somewhat based on\n", + " # https://github.com/bentrevett/pytorch-sentiment-analysis/\n", + " # blob/master/5%20-%20Multi-class%20Sentiment%20Analysis.ipynb\n", + " model.eval()\n", + " tokenized = [tok.text for tok in nlp.tokenizer(sentence)]\n", + " if len(tokenized) < min_len:\n", + " tokenized += [''] * (min_len - len(tokenized))\n", + " indexed = [TEXT.vocab.stoi[t] for t in tokenized]\n", + " length = [len(indexed)]\n", + " tensor = torch.LongTensor(indexed).to(DEVICE)\n", + " tensor = tensor.unsqueeze(1)\n", + " length_tensor = torch.LongTensor(length)\n", + " preds = model(tensor, length_tensor)\n", + " preds = torch.softmax(preds, dim=1)\n", + " \n", + " proba, class_label = preds.max(dim=1)\n", + " return proba.item(), class_label.item()" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Class Label: 2 -> Business\n", + "Probability: 0.878576934337616\n" + ] + } + ], + "source": [ + "text = \"\"\"\n", + "The windfall follows a tender offer by Z Holdings, which is controlled by SoftBank’s domestic wireless unit, \n", + "for half of Zozo’s shares this month.\n", + "\"\"\"\n", + "\n", + "proba, pred_label = predict_class(model, text)\n", + "\n", + "print(f'Class Label: {pred_label} -> {map_dictionary[pred_label]}')\n", + "print(f'Probability: {proba}')" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Class Label: 0 -> World\n", + "Probability: 0.9969592094421387\n" + ] + } + ], + "source": [ + "text = \"\"\"\n", + "EU data regulator issues first-ever sanction of an EU institution, \n", + "against the European parliament over its use of US-based NationBuilder to process voter data \n", + "\"\"\"\n", + "\n", + "proba, pred_label = predict_class(model, text)\n", + "\n", + "print(f'Class Label: {pred_label} -> {map_dictionary[pred_label]}')\n", + "print(f'Probability: {proba}')" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Class Label: 2 -> Business\n", + "Probability: 0.9953342080116272\n" + ] + } + ], + "source": [ + "text = \"\"\"\n", + "LG announces CEO Jo Seong-jin will be replaced by Brian Kwon Dec. 
1, amid 2020 \n", + "leadership shakeup and LG smartphone division's 18th straight quarterly loss\n", + "\"\"\"\n", + "\n", + "proba, pred_label = predict_class(model, text)\n", + "\n", + "print(f'Class Label: {pred_label} -> {map_dictionary[pred_label]}')\n", + "print(f'Probability: {proba}')" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "7lRusB3dF80X" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "numpy 1.17.2\n", + "pandas 0.24.2\n", + "torch 1.3.0\n", + "torchtext 0.4.0\n", + "spacy 2.2.3\n", + "\n" + ] + } + ], + "source": [ + "%watermark -iv" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "name": "rnn_lstm_packed_imdb.ipynb", + "provenance": [], + "version": "0.3.2" + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/pytorch_ipynb/rnn/rnn_bi_multilayer_lstm_own_csv_amazon-polarity.ipynb b/pytorch_ipynb/rnn/rnn_bi_multilayer_lstm_own_csv_amazon-polarity.ipynb new file mode 100644 index 0000000..fa87b9a --- /dev/null +++ b/pytorch_ipynb/rnn/rnn_bi_multilayer_lstm_own_csv_amazon-polarity.ipynb @@ -0,0 +1,1247 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n", + "- Author: Sebastian Raschka\n", + "- GitHub Repository: https://github.com/rasbt/deeplearning-models" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "vY4SK0xKAJgm" + }, + "source": [ + "# Bidirectional Multi-layer RNN with LSTM with Own Dataset in CSV Format (Amazon Review Polarity)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Dataset Description\n", + "\n", + "```\n", + "Amazon Review Polarity Dataset\n", + "\n", + "Version 3, Updated 09/09/2015\n", + "\n", + "ORIGIN\n", + "\n", + "The Amazon reviews dataset consists of reviews from amazon. The data span a period of 18 years, including ~35 million reviews up to March 2013. Reviews include product and user information, ratings, and a plaintext review. For more information, please refer to the following paper: J. McAuley and J. Leskovec. Hidden factors and hidden topics: understanding rating dimensions with review text. RecSys, 2013.\n", + "\n", + "The Amazon reviews polarity dataset is constructed by Xiang Zhang (xiang.zhang@nyu.edu) from the above dataset. It is used as a text classification benchmark in the following paper: Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances in Neural Information Processing Systems 28 (NIPS 2015).\n", + "\n", + "\n", + "DESCRIPTION\n", + "\n", + "The Amazon reviews polarity dataset is constructed by taking review score 1 and 2 as negative, and 4 and 5 as positive. Samples of score 3 is ignored. In the dataset, class 1 is the negative and class 2 is the positive. 
Each class has 1,800,000 training samples and 200,000 testing samples.\n", + "\n", + "The files train.csv and test.csv contain all the training samples as comma-separated values. There are 3 columns in them, corresponding to class index (1 or 2), review title and review text. The review title and text are escaped using double quotes (\"), and any internal double quote is escaped by 2 double quotes (\"\"). New lines are escaped by a backslash followed with an \"n\" character, that is \"\\n\".\n", + "\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "moNmVfuvnImW" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Sebastian Raschka \n", + "\n", + "CPython 3.7.3\n", + "IPython 7.9.0\n", + "\n", + "torch 1.3.0\n" + ] + } + ], + "source": [ + "%load_ext watermark\n", + "%watermark -a 'Sebastian Raschka' -v -p torch\n", + "\n", + "\n", + "import torch\n", + "import torch.nn.functional as F\n", + "from torchtext import data\n", + "from torchtext import datasets\n", + "import time\n", + "import random\n", + "import pandas as pd\n", + "import numpy as np\n", + "\n", + "torch.backends.cudnn.deterministic = True" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "GSRL42Qgy8I8" + }, + "source": [ + "## General Settings" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "OvW1RgfepCBq" + }, + "outputs": [], + "source": [ + "RANDOM_SEED = 123\n", + "torch.manual_seed(RANDOM_SEED)\n", + "\n", + "VOCABULARY_SIZE = 5000\n", + "LEARNING_RATE = 1e-3\n", + "BATCH_SIZE = 128\n", + "NUM_EPOCHS = 50\n", + "DROPOUT = 0.5\n", + "DEVICE = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')\n", + "\n", + "EMBEDDING_DIM = 128\n", + "BIDIRECTIONAL = True\n", + "HIDDEN_DIM = 256\n", + "NUM_LAYERS = 2\n", + "OUTPUT_DIM = 2" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "mQMmKUEisW4W" + }, + "source": [ + "## Dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The Amazon Review Polarity dataset is available from Xiang Zhang's Google Drive folder at\n", + "\n", + "https://drive.google.com/drive/u/0/folders/0Bz8a_Dbh9Qhbfll6bVpmNUtUcFdjYmF2SEpmZUZUcVNiMUw1TWN6RDV3a0JHT3kxLVhVR2M\n", + "\n", + "From the Google Drive folder, download the file \n", + "\n", + "- `amazon_review_polarity_csv.tar.gz`" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "amazon_review_polarity_csv/\n", + "amazon_review_polarity_csv/test.csv\n", + "amazon_review_polarity_csv/train.csv\n", + "amazon_review_polarity_csv/readme.txt\n" + ] + } + ], + "source": [ + "!tar xvzf amazon_review_polarity_csv.tar.gz" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Check that the dataset looks okay:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
classlabeltitlecontent
01Stuning even for the non-gamerThis sound track was beautiful! It paints the ...
11The best soundtrack ever to anything.I'm reading a lot of reviews saying that this ...
21Amazing!This soundtrack is my favorite music of all ti...
31Excellent SoundtrackI truly like this soundtrack and I enjoy video...
41Remember, Pull Your Jaw Off The Floor After He...If you've played the game, you know how divine...
\n", + "
" + ], + "text/plain": [ + " classlabel title \\\n", + "0 1 Stuning even for the non-gamer \n", + "1 1 The best soundtrack ever to anything. \n", + "2 1 Amazing! \n", + "3 1 Excellent Soundtrack \n", + "4 1 Remember, Pull Your Jaw Off The Floor After He... \n", + "\n", + " content \n", + "0 This sound track was beautiful! It paints the ... \n", + "1 I'm reading a lot of reviews saying that this ... \n", + "2 This soundtrack is my favorite music of all ti... \n", + "3 I truly like this soundtrack and I enjoy video... \n", + "4 If you've played the game, you know how divine... " + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df = pd.read_csv('amazon_review_polarity_csv/train.csv', header=None, index_col=None)\n", + "df.columns = ['classlabel', 'title', 'content']\n", + "df['classlabel'] = df['classlabel']-1\n", + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([0, 1])" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.unique(df['classlabel'].values)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([1800000, 1800000])" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.bincount(df['classlabel'])" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "df[['classlabel', 'content']].to_csv('amazon_review_polarity_csv/train_prepocessed.csv', index=None)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
classlabeltitlecontent
01Great CDMy lovely Pat has one of the GREAT voices of h...
11One of the best game music soundtracks - for a...Despite the fact that I have only played a sma...
20Batteries died within a year ...I bought this charger in Jul 2003 and it worke...
31works fine, but Maha Energy is betterCheck out Maha Energy's website. Their Powerex...
41Great for the non-audiophileReviewed quite a bit of the combo players and ...
\n", + "
" + ], + "text/plain": [ + " classlabel title \\\n", + "0 1 Great CD \n", + "1 1 One of the best game music soundtracks - for a... \n", + "2 0 Batteries died within a year ... \n", + "3 1 works fine, but Maha Energy is better \n", + "4 1 Great for the non-audiophile \n", + "\n", + " content \n", + "0 My lovely Pat has one of the GREAT voices of h... \n", + "1 Despite the fact that I have only played a sma... \n", + "2 I bought this charger in Jul 2003 and it worke... \n", + "3 Check out Maha Energy's website. Their Powerex... \n", + "4 Reviewed quite a bit of the combo players and ... " + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df = pd.read_csv('amazon_review_polarity_csv/test.csv', header=None, index_col=None)\n", + "df.columns = ['classlabel', 'title', 'content']\n", + "df['classlabel'] = df['classlabel']-1\n", + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([0, 1])" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.unique(df['classlabel'].values)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([200000, 200000])" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.bincount(df['classlabel'])" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "df[['classlabel', 'content']].to_csv('amazon_review_polarity_csv/test_prepocessed.csv', index=None)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "del df" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "4GnH64XvsV8n" + }, + "source": [ + "Define the Label and Text field formatters:" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "TEXT = data.Field(sequential=True,\n", + " tokenize='spacy',\n", + " include_lengths=True) # necessary for packed_padded_sequence\n", + "\n", + "LABEL = data.LabelField(dtype=torch.float)\n", + "\n", + "\n", + "# If you get an error [E050] Can't find model 'en'\n", + "# you need to run the following on your command line:\n", + "# python -m spacy download en" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Process the dataset:" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "fields = [('classlabel', LABEL), ('content', TEXT)]\n", + "\n", + "train_dataset = data.TabularDataset(\n", + " path=\"amazon_review_polarity_csv/train_prepocessed.csv\", format='csv',\n", + " skip_header=True, fields=fields)\n", + "\n", + "test_dataset = data.TabularDataset(\n", + " path=\"amazon_review_polarity_csv/test_prepocessed.csv\", format='csv',\n", + " skip_header=True, fields=fields)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Split the training dataset into training and validation:" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 68 + }, + "colab_type": "code", + "id": "WZ_4jiHVnMxN", + "outputId": "dfa51c04-4845-44c3-f50b-d36d41f132b8" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ 
+ "Num Train: 3420000\n", + "Num Valid: 180000\n" + ] + } + ], + "source": [ + "train_data, valid_data = train_dataset.split(\n", + " split_ratio=[0.95, 0.05],\n", + " random_state=random.seed(RANDOM_SEED))\n", + "\n", + "print(f'Num Train: {len(train_data)}')\n", + "print(f'Num Valid: {len(valid_data)}')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "L-TBwKWPslPa" + }, + "source": [ + "Build the vocabulary based on the top \"VOCABULARY_SIZE\" words:" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 51 + }, + "colab_type": "code", + "id": "e8uNrjdtn4A8", + "outputId": "6cf499d7-7722-4da0-8576-ee0f218cc6e3" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Vocabulary size: 5002\n", + "Number of classes: 2\n" + ] + } + ], + "source": [ + "TEXT.build_vocab(train_data,\n", + " max_size=VOCABULARY_SIZE,\n", + " vectors='glove.6B.100d',\n", + " unk_init=torch.Tensor.normal_)\n", + "\n", + "LABEL.build_vocab(train_data)\n", + "\n", + "print(f'Vocabulary size: {len(TEXT.vocab)}')\n", + "print(f'Number of classes: {len(LABEL.vocab)}')" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['1', '0']" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "list(LABEL.vocab.freqs)[-10:]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "JpEMNInXtZsb" + }, + "source": [ + "The TEXT.vocab dictionary will contain the word counts and indices. The reason why the number of words is VOCABULARY_SIZE + 2 is that it contains to special tokens for padding and unknown words: `` and ``." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "eIQ_zfKLwjKm" + }, + "source": [ + "Make dataset iterators:" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "i7JiHR1stHNF" + }, + "outputs": [], + "source": [ + "train_loader, valid_loader, test_loader = data.BucketIterator.splits(\n", + " (train_data, valid_data, test_dataset), \n", + " batch_size=BATCH_SIZE,\n", + " sort_within_batch=True, # necessary for packed_padded_sequence\n", + " sort_key=lambda x: len(x.content),\n", + " device=DEVICE)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "R0pT_dMRvicQ" + }, + "source": [ + "Testing the iterators (note that the number of rows depends on the longest document in the respective batch):" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 204 + }, + "colab_type": "code", + "id": "y8SP_FccutT0", + "outputId": "fe33763a-4560-4dee-adee-31cc6c48b0b2" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train\n", + "Text matrix size: torch.Size([74, 128])\n", + "Target vector size: torch.Size([128])\n", + "\n", + "Valid:\n", + "Text matrix size: torch.Size([14, 128])\n", + "Target vector size: torch.Size([128])\n", + "\n", + "Test:\n", + "Text matrix size: torch.Size([12, 128])\n", + "Target vector size: torch.Size([128])\n" + ] + } + ], + "source": [ + "print('Train')\n", + "for batch in train_loader:\n", + " print(f'Text matrix size: {batch.content[0].size()}')\n", + " print(f'Target vector size: {batch.classlabel.size()}')\n", + " break\n", + " \n", + "print('\\nValid:')\n", + "for batch in valid_loader:\n", + " print(f'Text matrix size: {batch.content[0].size()}')\n", + " print(f'Target vector size: {batch.classlabel.size()}')\n", + " break\n", + " \n", + "print('\\nTest:')\n", + "for batch in test_loader:\n", + " print(f'Text matrix size: {batch.content[0].size()}')\n", + " print(f'Target vector size: {batch.classlabel.size()}')\n", + " break" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "G_grdW3pxCzz" + }, + "source": [ + "## Model" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "nQIUm5EjxFNa" + }, + "outputs": [], + "source": [ + "import torch.nn as nn\n", + "\n", + "\n", + "class RNN(nn.Module):\n", + " def __init__(self, input_dim, embedding_dim, bidirectional, hidden_dim, num_layers, output_dim, dropout, pad_idx):\n", + " \n", + " super().__init__()\n", + " \n", + " self.embedding = nn.Embedding(input_dim, embedding_dim, padding_idx=pad_idx)\n", + " self.rnn = nn.LSTM(embedding_dim, \n", + " hidden_dim,\n", + " num_layers=num_layers,\n", + " bidirectional=bidirectional, \n", + " dropout=dropout)\n", + " self.fc1 = nn.Linear(hidden_dim * num_layers, 64)\n", + " self.fc2 = nn.Linear(64, output_dim)\n", + " self.dropout = nn.Dropout(dropout)\n", + " \n", + " def forward(self, text, text_length):\n", + "\n", + " embedded = self.dropout(self.embedding(text))\n", + " packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, text_length)\n", + " packed_output, (hidden, cell) = self.rnn(packed_embedded)\n", + " output, output_lengths = nn.utils.rnn.pad_packed_sequence(packed_output)\n", + " hidden = self.dropout(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1))\n", + " hidden = 
self.fc1(hidden)\n", + " hidden = self.fc2(hidden) # map to OUTPUT_DIM class logits\n", + " return hidden" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "Ik3NF3faxFmZ" + }, + "outputs": [], + "source": [ + "INPUT_DIM = len(TEXT.vocab)\n", + "\n", + "PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token]\n", + "\n", + "torch.manual_seed(RANDOM_SEED)\n", + "model = RNN(INPUT_DIM, EMBEDDING_DIM, BIDIRECTIONAL, HIDDEN_DIM, NUM_LAYERS, OUTPUT_DIM, DROPOUT, PAD_IDX)\n", + "model = model.to(DEVICE)\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Lv9Ny9di6VcI" + }, + "source": [ + "## Training" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "T5t1Afn4xO11" + }, + "outputs": [], + "source": [ + "def compute_accuracy(model, data_loader, device):\n", + " model.eval()\n", + " correct_pred, num_examples = 0, 0\n", + " with torch.no_grad():\n", + " for batch_idx, batch_data in enumerate(data_loader):\n", + " text, text_lengths = batch_data.content\n", + " logits = model(text, text_lengths).squeeze(1)\n", + " _, predicted_labels = torch.max(logits, 1)\n", + " num_examples += batch_data.classlabel.size(0)\n", + " correct_pred += (predicted_labels.long() == batch_data.classlabel.long()).sum()\n", + " return correct_pred.float()/num_examples * 100" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1836 + }, + "colab_type": "code", + "id": "EABZM8Vo0ilB", + "outputId": "5d45e293-9909-4588-e793-8dfaf72e5c67" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch: 001/050 | Batch 000/26719 | Cost: 4.1805\n", + "Epoch: 001/050 | Batch 10000/26719 | Cost: 0.2005\n", + "Epoch: 001/050 | Batch 20000/26719 | Cost: 0.1998\n", + "training accuracy: 93.34%\n", + "valid accuracy: 93.27%\n", + "Time elapsed: 33.40 min\n", + "Epoch: 002/050 | Batch 000/26719 | Cost: 0.1659\n", + "Epoch: 002/050 | Batch 10000/26719 | Cost: 0.1326\n", + "Epoch: 002/050 | Batch 20000/26719 | Cost: 0.1470\n", + "training accuracy: 93.82%\n", + "valid accuracy: 93.63%\n", + "Time elapsed: 66.69 min\n", + "Epoch: 003/050 | Batch 000/26719 | Cost: 0.1256\n", + "Epoch: 003/050 | Batch 10000/26719 | Cost: 0.1980\n", + "Epoch: 003/050 | Batch 20000/26719 | Cost: 0.2041\n", + "training accuracy: 93.98%\n", + "valid accuracy: 93.82%\n", + "Time elapsed: 100.02 min\n", + "Epoch: 004/050 | Batch 000/26719 | Cost: 0.2103\n", + "Epoch: 004/050 | Batch 10000/26719 | Cost: 0.1100\n", + "Epoch: 004/050 | Batch 20000/26719 | Cost: 0.1851\n", + "training accuracy: 94.11%\n", + "valid accuracy: 93.93%\n", + "Time elapsed: 133.32 min\n", + "Epoch: 005/050 | Batch 000/26719 | Cost: 0.2196\n", + "Epoch: 005/050 | Batch 10000/26719 | Cost: 0.1209\n", + "Epoch: 005/050 | Batch 20000/26719 | Cost: 0.2147\n", + "training accuracy: 94.13%\n", + "valid accuracy: 93.93%\n", + "Time elapsed: 166.67 min\n", + "Epoch: 006/050 | Batch 000/26719 | Cost: 0.1908\n", + "Epoch: 006/050 | Batch 10000/26719 | Cost: 0.2187\n", + "Epoch: 006/050 | Batch 20000/26719 | Cost: 0.2253\n", + "training accuracy: 94.15%\n", + "valid accuracy: 93.93%\n", + "Time elapsed: 199.87 min\n", + "Epoch: 007/050 | Batch 000/26719 | Cost: 0.1990\n", + "Epoch: 007/050 | Batch 10000/26719 | Cost: 0.1928\n", + "Epoch: 007/050 | Batch 20000/26719 | Cost: 0.2113\n", + "training accuracy: 
94.21%\n", + "valid accuracy: 93.97%\n", + "Time elapsed: 233.25 min\n", + "Epoch: 008/050 | Batch 000/26719 | Cost: 0.1753\n", + "Epoch: 008/050 | Batch 10000/26719 | Cost: 0.1708\n", + "Epoch: 008/050 | Batch 20000/26719 | Cost: 0.2158\n", + "training accuracy: 94.21%\n", + "valid accuracy: 93.97%\n", + "Time elapsed: 266.51 min\n", + "Epoch: 009/050 | Batch 000/26719 | Cost: 0.2423\n", + "Epoch: 009/050 | Batch 10000/26719 | Cost: 0.1097\n", + "Epoch: 009/050 | Batch 20000/26719 | Cost: 0.1727\n", + "training accuracy: 94.18%\n", + "valid accuracy: 93.98%\n", + "Time elapsed: 299.86 min\n", + "Epoch: 010/050 | Batch 000/26719 | Cost: 0.1474\n", + "Epoch: 010/050 | Batch 10000/26719 | Cost: 0.2041\n", + "Epoch: 010/050 | Batch 20000/26719 | Cost: 0.1127\n", + "training accuracy: 94.13%\n", + "valid accuracy: 93.91%\n", + "Time elapsed: 333.10 min\n", + "Epoch: 011/050 | Batch 000/26719 | Cost: 0.1643\n", + "Epoch: 011/050 | Batch 10000/26719 | Cost: 0.1772\n", + "Epoch: 011/050 | Batch 20000/26719 | Cost: 0.1586\n", + "training accuracy: 94.13%\n", + "valid accuracy: 93.92%\n", + "Time elapsed: 366.48 min\n", + "Epoch: 012/050 | Batch 000/26719 | Cost: 0.1335\n", + "Epoch: 012/050 | Batch 10000/26719 | Cost: 0.1680\n", + "Epoch: 012/050 | Batch 20000/26719 | Cost: 0.1775\n", + "training accuracy: 94.04%\n", + "valid accuracy: 93.80%\n", + "Time elapsed: 399.85 min\n", + "Epoch: 013/050 | Batch 000/26719 | Cost: 0.1896\n", + "Epoch: 013/050 | Batch 10000/26719 | Cost: 0.0957\n", + "Epoch: 013/050 | Batch 20000/26719 | Cost: 0.1700\n", + "training accuracy: 94.02%\n", + "valid accuracy: 93.80%\n", + "Time elapsed: 432.30 min\n", + "Epoch: 014/050 | Batch 000/26719 | Cost: 0.1370\n", + "Epoch: 014/050 | Batch 10000/26719 | Cost: 0.1449\n", + "Epoch: 014/050 | Batch 20000/26719 | Cost: 0.1874\n", + "training accuracy: 93.96%\n", + "valid accuracy: 93.80%\n", + "Time elapsed: 463.91 min\n", + "Epoch: 015/050 | Batch 000/26719 | Cost: 0.1289\n", + "Epoch: 015/050 | Batch 10000/26719 | Cost: 0.1852\n", + "Epoch: 015/050 | Batch 20000/26719 | Cost: 0.1166\n", + "training accuracy: 93.79%\n", + "valid accuracy: 93.64%\n", + "Time elapsed: 495.59 min\n", + "Epoch: 016/050 | Batch 000/26719 | Cost: 0.1109\n", + "Epoch: 016/050 | Batch 10000/26719 | Cost: 0.1259\n", + "Epoch: 016/050 | Batch 20000/26719 | Cost: 0.1309\n", + "training accuracy: 93.75%\n", + "valid accuracy: 93.58%\n", + "Time elapsed: 527.20 min\n", + "Epoch: 017/050 | Batch 000/26719 | Cost: 0.2273\n", + "Epoch: 017/050 | Batch 10000/26719 | Cost: 0.1037\n", + "Epoch: 017/050 | Batch 20000/26719 | Cost: 0.1274\n", + "training accuracy: 93.58%\n", + "valid accuracy: 93.43%\n", + "Time elapsed: 558.80 min\n", + "Epoch: 018/050 | Batch 000/26719 | Cost: 0.1924\n", + "Epoch: 018/050 | Batch 10000/26719 | Cost: 0.1870\n", + "Epoch: 018/050 | Batch 20000/26719 | Cost: 0.2183\n", + "training accuracy: 93.61%\n", + "valid accuracy: 93.51%\n", + "Time elapsed: 590.48 min\n", + "Epoch: 019/050 | Batch 000/26719 | Cost: 0.1955\n", + "Epoch: 019/050 | Batch 10000/26719 | Cost: 0.1745\n", + "Epoch: 019/050 | Batch 20000/26719 | Cost: 0.1339\n", + "training accuracy: 93.49%\n", + "valid accuracy: 93.43%\n", + "Time elapsed: 622.06 min\n", + "Epoch: 020/050 | Batch 000/26719 | Cost: 0.1498\n", + "Epoch: 020/050 | Batch 10000/26719 | Cost: 0.2582\n", + "Epoch: 020/050 | Batch 20000/26719 | Cost: 0.2263\n", + "training accuracy: 93.41%\n", + "valid accuracy: 93.32%\n", + "Time elapsed: 653.69 min\n", + "Epoch: 021/050 | Batch 000/26719 | Cost: 
0.2266\n", + "Epoch: 021/050 | Batch 10000/26719 | Cost: 0.1824\n", + "Epoch: 021/050 | Batch 20000/26719 | Cost: 0.2128\n", + "training accuracy: 93.32%\n", + "valid accuracy: 93.18%\n", + "Time elapsed: 685.43 min\n", + "Epoch: 022/050 | Batch 000/26719 | Cost: 0.1637\n", + "Epoch: 022/050 | Batch 10000/26719 | Cost: 0.2462\n", + "Epoch: 022/050 | Batch 20000/26719 | Cost: 0.1890\n", + "training accuracy: 93.24%\n", + "valid accuracy: 93.13%\n", + "Time elapsed: 716.98 min\n", + "Epoch: 023/050 | Batch 000/26719 | Cost: 0.2072\n", + "Epoch: 023/050 | Batch 10000/26719 | Cost: 0.1904\n", + "Epoch: 023/050 | Batch 20000/26719 | Cost: 0.2408\n", + "training accuracy: 93.13%\n", + "valid accuracy: 93.02%\n", + "Time elapsed: 748.55 min\n", + "Epoch: 024/050 | Batch 000/26719 | Cost: 0.1655\n", + "Epoch: 024/050 | Batch 10000/26719 | Cost: 0.2909\n", + "Epoch: 024/050 | Batch 20000/26719 | Cost: 0.1979\n", + "training accuracy: 93.05%\n", + "valid accuracy: 92.97%\n", + "Time elapsed: 780.21 min\n", + "Epoch: 025/050 | Batch 000/26719 | Cost: 0.1742\n", + "Epoch: 025/050 | Batch 10000/26719 | Cost: 0.2666\n", + "Epoch: 025/050 | Batch 20000/26719 | Cost: 0.2489\n", + "training accuracy: 92.97%\n", + "valid accuracy: 92.84%\n", + "Time elapsed: 811.86 min\n", + "Epoch: 026/050 | Batch 000/26719 | Cost: 0.2000\n", + "Epoch: 026/050 | Batch 10000/26719 | Cost: 0.1438\n", + "Epoch: 026/050 | Batch 20000/26719 | Cost: 0.1771\n", + "training accuracy: 92.80%\n", + "valid accuracy: 92.69%\n", + "Time elapsed: 843.59 min\n", + "Epoch: 027/050 | Batch 000/26719 | Cost: 0.1902\n", + "Epoch: 027/050 | Batch 10000/26719 | Cost: 0.1842\n", + "Epoch: 027/050 | Batch 20000/26719 | Cost: 0.2043\n", + "training accuracy: 92.93%\n", + "valid accuracy: 92.85%\n", + "Time elapsed: 875.26 min\n", + "Epoch: 028/050 | Batch 000/26719 | Cost: 0.1836\n", + "Epoch: 028/050 | Batch 10000/26719 | Cost: 0.1861\n", + "Epoch: 028/050 | Batch 20000/26719 | Cost: 0.1953\n", + "training accuracy: 92.85%\n", + "valid accuracy: 92.76%\n", + "Time elapsed: 906.92 min\n", + "Epoch: 029/050 | Batch 000/26719 | Cost: 0.2089\n", + "Epoch: 029/050 | Batch 10000/26719 | Cost: 0.2378\n", + "Epoch: 029/050 | Batch 20000/26719 | Cost: 0.1476\n", + "training accuracy: 92.84%\n", + "valid accuracy: 92.74%\n", + "Time elapsed: 938.51 min\n", + "Epoch: 030/050 | Batch 000/26719 | Cost: 0.1816\n", + "Epoch: 030/050 | Batch 10000/26719 | Cost: 0.2420\n", + "Epoch: 030/050 | Batch 20000/26719 | Cost: 0.1891\n", + "training accuracy: 92.73%\n", + "valid accuracy: 92.63%\n", + "Time elapsed: 970.14 min\n", + "Epoch: 031/050 | Batch 000/26719 | Cost: 0.1959\n", + "Epoch: 031/050 | Batch 10000/26719 | Cost: 0.2809\n", + "Epoch: 031/050 | Batch 20000/26719 | Cost: 0.2692\n", + "training accuracy: 92.65%\n", + "valid accuracy: 92.63%\n", + "Time elapsed: 1001.72 min\n", + "Epoch: 032/050 | Batch 000/26719 | Cost: 0.1845\n", + "Epoch: 032/050 | Batch 10000/26719 | Cost: 0.2390\n", + "Epoch: 032/050 | Batch 20000/26719 | Cost: 0.1673\n", + "training accuracy: 92.54%\n", + "valid accuracy: 92.50%\n", + "Time elapsed: 1033.34 min\n", + "Epoch: 033/050 | Batch 000/26719 | Cost: 0.1612\n", + "Epoch: 033/050 | Batch 10000/26719 | Cost: 0.2473\n", + "Epoch: 033/050 | Batch 20000/26719 | Cost: 0.2368\n", + "training accuracy: 92.52%\n", + "valid accuracy: 92.43%\n", + "Time elapsed: 1064.98 min\n", + "Epoch: 034/050 | Batch 000/26719 | Cost: 0.1739\n", + "Epoch: 034/050 | Batch 10000/26719 | Cost: 0.2465\n", + "Epoch: 034/050 | Batch 20000/26719 | Cost: 
0.2751\n", + "training accuracy: 92.43%\n", + "valid accuracy: 92.35%\n", + "Time elapsed: 1096.60 min\n", + "Epoch: 035/050 | Batch 000/26719 | Cost: 0.1641\n", + "Epoch: 035/050 | Batch 10000/26719 | Cost: 0.2993\n", + "Epoch: 035/050 | Batch 20000/26719 | Cost: 0.2110\n", + "training accuracy: 92.44%\n", + "valid accuracy: 92.38%\n", + "Time elapsed: 1128.23 min\n", + "Epoch: 036/050 | Batch 000/26719 | Cost: 0.1998\n", + "Epoch: 036/050 | Batch 10000/26719 | Cost: 0.4061\n", + "Epoch: 036/050 | Batch 20000/26719 | Cost: 0.3348\n", + "training accuracy: 92.34%\n", + "valid accuracy: 92.23%\n", + "Time elapsed: 1159.86 min\n", + "Epoch: 037/050 | Batch 000/26719 | Cost: 0.2720\n", + "Epoch: 037/050 | Batch 10000/26719 | Cost: 0.1884\n", + "Epoch: 037/050 | Batch 20000/26719 | Cost: 0.2429\n", + "training accuracy: 92.38%\n", + "valid accuracy: 92.35%\n", + "Time elapsed: 1191.48 min\n", + "Epoch: 038/050 | Batch 000/26719 | Cost: 0.1869\n", + "Epoch: 038/050 | Batch 10000/26719 | Cost: 0.3093\n", + "Epoch: 038/050 | Batch 20000/26719 | Cost: 0.2258\n", + "training accuracy: 92.32%\n", + "valid accuracy: 92.33%\n", + "Time elapsed: 1223.13 min\n", + "Epoch: 039/050 | Batch 000/26719 | Cost: 0.2780\n", + "Epoch: 039/050 | Batch 10000/26719 | Cost: 0.2481\n", + "Epoch: 039/050 | Batch 20000/26719 | Cost: 0.2593\n", + "training accuracy: 92.34%\n", + "valid accuracy: 92.31%\n", + "Time elapsed: 1254.79 min\n", + "Epoch: 040/050 | Batch 000/26719 | Cost: 0.1992\n", + "Epoch: 040/050 | Batch 10000/26719 | Cost: 0.2254\n", + "Epoch: 040/050 | Batch 20000/26719 | Cost: 0.2145\n", + "training accuracy: 92.31%\n", + "valid accuracy: 92.25%\n", + "Time elapsed: 1286.39 min\n", + "Epoch: 041/050 | Batch 000/26719 | Cost: 0.1949\n", + "Epoch: 041/050 | Batch 10000/26719 | Cost: 0.2056\n", + "Epoch: 041/050 | Batch 20000/26719 | Cost: 0.2562\n", + "training accuracy: 92.15%\n", + "valid accuracy: 92.10%\n", + "Time elapsed: 1318.01 min\n", + "Epoch: 042/050 | Batch 000/26719 | Cost: 0.2261\n", + "Epoch: 042/050 | Batch 10000/26719 | Cost: 0.2665\n", + "Epoch: 042/050 | Batch 20000/26719 | Cost: 0.2810\n", + "training accuracy: 91.95%\n", + "valid accuracy: 91.88%\n", + "Time elapsed: 1349.75 min\n", + "Epoch: 043/050 | Batch 000/26719 | Cost: 0.2078\n", + "Epoch: 043/050 | Batch 10000/26719 | Cost: 0.2598\n", + "Epoch: 043/050 | Batch 20000/26719 | Cost: 0.2550\n", + "training accuracy: 92.00%\n", + "valid accuracy: 91.96%\n", + "Time elapsed: 1381.34 min\n", + "Epoch: 044/050 | Batch 000/26719 | Cost: 0.1947\n", + "Epoch: 044/050 | Batch 10000/26719 | Cost: 0.2332\n", + "Epoch: 044/050 | Batch 20000/26719 | Cost: 0.3156\n", + "training accuracy: 91.84%\n", + "valid accuracy: 91.83%\n", + "Time elapsed: 1412.81 min\n", + "Epoch: 045/050 | Batch 000/26719 | Cost: 0.2643\n", + "Epoch: 045/050 | Batch 10000/26719 | Cost: 0.2745\n", + "Epoch: 045/050 | Batch 20000/26719 | Cost: 0.3741\n", + "training accuracy: 91.98%\n", + "valid accuracy: 91.94%\n", + "Time elapsed: 1444.41 min\n", + "Epoch: 046/050 | Batch 000/26719 | Cost: 0.2029\n", + "Epoch: 046/050 | Batch 10000/26719 | Cost: 0.2028\n", + "Epoch: 046/050 | Batch 20000/26719 | Cost: 0.2525\n", + "training accuracy: 91.84%\n", + "valid accuracy: 91.86%\n", + "Time elapsed: 1476.07 min\n", + "Epoch: 047/050 | Batch 000/26719 | Cost: 0.2104\n", + "Epoch: 047/050 | Batch 10000/26719 | Cost: 0.1793\n", + "Epoch: 047/050 | Batch 20000/26719 | Cost: 0.2022\n", + "training accuracy: 91.75%\n", + "valid accuracy: 91.73%\n", + "Time elapsed: 1507.73 min\n", + 
"Epoch: 048/050 | Batch 000/26719 | Cost: 0.3482\n", + "Epoch: 048/050 | Batch 10000/26719 | Cost: 0.2211\n", + "Epoch: 048/050 | Batch 20000/26719 | Cost: 0.2857\n", + "training accuracy: 91.62%\n", + "valid accuracy: 91.56%\n", + "Time elapsed: 1539.42 min\n", + "Epoch: 049/050 | Batch 000/26719 | Cost: 0.2514\n", + "Epoch: 049/050 | Batch 10000/26719 | Cost: 0.2387\n", + "Epoch: 049/050 | Batch 20000/26719 | Cost: 0.2515\n", + "training accuracy: 91.54%\n", + "valid accuracy: 91.47%\n", + "Time elapsed: 1571.06 min\n", + "Epoch: 050/050 | Batch 000/26719 | Cost: 0.2802\n", + "Epoch: 050/050 | Batch 10000/26719 | Cost: 0.3489\n", + "Epoch: 050/050 | Batch 20000/26719 | Cost: 0.2609\n", + "training accuracy: 91.49%\n", + "valid accuracy: 91.40%\n", + "Time elapsed: 1602.62 min\n", + "Total Training Time: 1602.62 min\n", + "Test accuracy: 91.36%\n" + ] + } + ], + "source": [ + "start_time = time.time()\n", + "\n", + "for epoch in range(NUM_EPOCHS):\n", + " model.train()\n", + " for batch_idx, batch_data in enumerate(train_loader):\n", + " \n", + " text, text_lengths = batch_data.content\n", + " \n", + " ### FORWARD AND BACK PROP\n", + " logits = model(text, text_lengths).squeeze(1)\n", + " cost = F.cross_entropy(logits, batch_data.classlabel.long())\n", + " optimizer.zero_grad()\n", + " \n", + " cost.backward()\n", + " \n", + " ### UPDATE MODEL PARAMETERS\n", + " optimizer.step()\n", + " \n", + " ### LOGGING\n", + " if not batch_idx % 10000:\n", + " print (f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} | '\n", + " f'Batch {batch_idx:03d}/{len(train_loader):03d} | '\n", + " f'Cost: {cost:.4f}')\n", + "\n", + " with torch.set_grad_enabled(False):\n", + " print(f'training accuracy: '\n", + " f'{compute_accuracy(model, train_loader, DEVICE):.2f}%'\n", + " f'\\nvalid accuracy: '\n", + " f'{compute_accuracy(model, valid_loader, DEVICE):.2f}%')\n", + " \n", + " print(f'Time elapsed: {(time.time() - start_time)/60:.2f} min')\n", + " \n", + "print(f'Total Training Time: {(time.time() - start_time)/60:.2f} min')\n", + "print(f'Test accuracy: {compute_accuracy(model, test_loader, DEVICE):.2f}%')" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "7lRusB3dF80X" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "spacy 2.2.3\n", + "pandas 0.24.2\n", + "torchtext 0.4.0\n", + "numpy 1.17.2\n", + "torch 1.3.0\n", + "\n" + ] + } + ], + "source": [ + "%watermark -iv" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [], + "source": [ + "torch.save(model.state_dict(), 'rnn_bi_multilayer_lstm_own_csv_amazon-polarity.pt')" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "name": "rnn_lstm_packed_imdb.ipynb", + "provenance": [], + "version": "0.3.2" + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/pytorch_ipynb/rnn/rnn_bi_multilayer_lstm_own_csv_yelp-polarity.ipynb b/pytorch_ipynb/rnn/rnn_bi_multilayer_lstm_own_csv_yelp-polarity.ipynb new file mode 100644 index 0000000..40e9621 --- /dev/null +++ b/pytorch_ipynb/rnn/rnn_bi_multilayer_lstm_own_csv_yelp-polarity.ipynb @@ -0,0 +1,1451 @@ +{ + 
"cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Deep Learning Models -- A collection of various deep learning architectures, models, and tips for TensorFlow and PyTorch in Jupyter Notebooks.\n", + "- Author: Sebastian Raschka\n", + "- GitHub Repository: https://github.com/rasbt/deeplearning-models" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "vY4SK0xKAJgm" + }, + "source": [ + "# Bidirectional Multi-layer RNN with LSTM with Own Dataset in CSV Format (Yelp Review Polarity)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Dataset Description\n", + "\n", + "```\n", + "Yelp Review Polarity Dataset\n", + "\n", + "Version 1, Updated 09/09/2015\n", + "\n", + "ORIGIN\n", + "\n", + "The Yelp reviews dataset consists of reviews from Yelp. It is extracted from the Yelp Dataset Challenge 2015 data. For more information, please refer to http://www.yelp.com/dataset_challenge\n", + "\n", + "The Yelp reviews polarity dataset is constructed by Xiang Zhang (xiang.zhang@nyu.edu) from the above dataset. It is first used as a text classification benchmark in the following paper: Xiang Zhang, Junbo Zhao, Yann LeCun. Character-level Convolutional Networks for Text Classification. Advances in Neural Information Processing Systems 28 (NIPS 2015).\n", + "\n", + "\n", + "DESCRIPTION\n", + "\n", + "The Yelp reviews polarity dataset is constructed by considering stars 1 and 2 negative, and 3 and 4 positive. For each polarity 280,000 training samples and 19,000 testing samples are take randomly. In total there are 560,000 trainig samples and 38,000 testing samples. Negative polarity is class 1, and positive class 2.\n", + "\n", + "The files train.csv and test.csv contain all the training samples as comma-sparated values. There are 2 columns in them, corresponding to class index (1 and 2) and review text. The review texts are escaped using double quotes (\"), and any internal double quote is escaped by 2 double quotes (\"\"). 
New lines are escaped by a backslash followed with an \"n\" character, that is \"\\n\".\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "moNmVfuvnImW" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Sebastian Raschka \n", + "\n", + "CPython 3.7.3\n", + "IPython 7.9.0\n", + "\n", + "torch 1.3.0\n" + ] + } + ], + "source": [ + "%load_ext watermark\n", + "%watermark -a 'Sebastian Raschka' -v -p torch\n", + "\n", + "\n", + "import torch\n", + "import torch.nn.functional as F\n", + "from torchtext import data\n", + "from torchtext import datasets\n", + "import time\n", + "import random\n", + "import pandas as pd\n", + "import numpy as np\n", + "\n", + "torch.backends.cudnn.deterministic = True" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "GSRL42Qgy8I8" + }, + "source": [ + "## General Settings" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "OvW1RgfepCBq" + }, + "outputs": [], + "source": [ + "RANDOM_SEED = 123\n", + "torch.manual_seed(RANDOM_SEED)\n", + "\n", + "VOCABULARY_SIZE = 5000\n", + "LEARNING_RATE = 1e-3\n", + "BATCH_SIZE = 128\n", + "NUM_EPOCHS = 50\n", + "DROPOUT = 0.5\n", + "DEVICE = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu')\n", + "\n", + "EMBEDDING_DIM = 128\n", + "BIDIRECTIONAL = True\n", + "HIDDEN_DIM = 256\n", + "NUM_LAYERS = 2\n", + "OUTPUT_DIM = 2" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "mQMmKUEisW4W" + }, + "source": [ + "## Dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The Yelp Review Polarity dataset is available from Xiang Zhang's Google Drive folder at\n", + "\n", + "https://drive.google.com/drive/u/0/folders/0Bz8a_Dbh9Qhbfll6bVpmNUtUcFdjYmF2SEpmZUZUcVNiMUw1TWN6RDV3a0JHT3kxLVhVR2M\n", + "\n", + "From the Google Drive folder, download the file \n", + "\n", + "- `yelp_review_polarity_csv.tar.gz`" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "yelp_review_polarity_csv/\n", + "yelp_review_polarity_csv/readme.txt\n", + "yelp_review_polarity_csv/test.csv\n", + "yelp_review_polarity_csv/train.csv\n" + ] + } + ], + "source": [ + "!tar xvzf yelp_review_polarity_csv.tar.gz" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Check that the dataset looks okay:" + ] + },
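Given the quoting rules spelled out in the readme above (double-quoted fields, doubled internal quotes, and literal `\n` sequences standing in for newlines), the raw file can also be sanity-checked directly with pandas. A minimal sketch, not part of the original notebook; the expected row count comes from the readme:

```python
import pandas as pd

# train.csv has no header row; pandas' default CSV quoting rules already
# handle the doubled double-quotes. The backslash-n sequences inside a
# review are two literal characters, not real newlines, and can be
# replaced if plain running text is preferred.
df = pd.read_csv('yelp_review_polarity_csv/train.csv',
                 header=None, names=['classlabel', 'content'])
df['content'] = df['content'].str.replace('\\n', ' ', regex=False)
print(df.shape)  # expected per the readme: (560000, 2)
```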
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
classlabelcontent
00Unfortunately, the frustration of being Dr. Go...
11Been going to Dr. Goldberg for over 10 years. ...
20I don't know what Dr. Goldberg was like before...
30I'm writing this review to give you a heads up...
41All the food is great here. But the best thing...
\n", + "
" + ], + "text/plain": [ + " classlabel content\n", + "0 0 Unfortunately, the frustration of being Dr. Go...\n", + "1 1 Been going to Dr. Goldberg for over 10 years. ...\n", + "2 0 I don't know what Dr. Goldberg was like before...\n", + "3 0 I'm writing this review to give you a heads up...\n", + "4 1 All the food is great here. But the best thing..." + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df = pd.read_csv('yelp_review_polarity_csv/train.csv', header=None, index_col=None)\n", + "df.columns = ['classlabel', 'content']\n", + "df['classlabel'] = df['classlabel']-1\n", + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([0, 1])" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.unique(df['classlabel'].values)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([280000, 280000])" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.bincount(df['classlabel'])" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "df[['classlabel', 'content']].to_csv('yelp_review_polarity_csv/train_prepocessed.csv', index=None)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
classlabelcontent
01Contrary to other reviews, I have zero complai...
10Last summer I had an appointment to get new ti...
21Friendly staff, same starbucks fair you get an...
30The food is good. Unfortunately the service is...
41Even when we didn't have a car Filene's Baseme...
\n", + "
" + ], + "text/plain": [ + " classlabel content\n", + "0 1 Contrary to other reviews, I have zero complai...\n", + "1 0 Last summer I had an appointment to get new ti...\n", + "2 1 Friendly staff, same starbucks fair you get an...\n", + "3 0 The food is good. Unfortunately the service is...\n", + "4 1 Even when we didn't have a car Filene's Baseme..." + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "df = pd.read_csv('yelp_review_polarity_csv/test.csv', header=None, index_col=None)\n", + "df.columns = ['classlabel', 'content']\n", + "df['classlabel'] = df['classlabel']-1\n", + "df.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([0, 1])" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.unique(df['classlabel'].values)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([19000, 19000])" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.bincount(df['classlabel'])" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "df[['classlabel', 'content']].to_csv('yelp_review_polarity_csv/test_prepocessed.csv', index=None)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "del df" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "4GnH64XvsV8n" + }, + "source": [ + "Define the Label and Text field formatters:" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "TEXT = data.Field(sequential=True,\n", + " tokenize='spacy',\n", + " include_lengths=True) # necessary for packed_padded_sequence\n", + "\n", + "LABEL = data.LabelField(dtype=torch.float)\n", + "\n", + "\n", + "# If you get an error [E050] Can't find model 'en'\n", + "# you need to run the following on your command line:\n", + "# python -m spacy download en" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Process the dataset:" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "fields = [('classlabel', LABEL), ('content', TEXT)]\n", + "\n", + "train_dataset = data.TabularDataset(\n", + " path=\"yelp_review_polarity_csv/train_prepocessed.csv\", format='csv',\n", + " skip_header=True, fields=fields)\n", + "\n", + "test_dataset = data.TabularDataset(\n", + " path=\"yelp_review_polarity_csv/test_prepocessed.csv\", format='csv',\n", + " skip_header=True, fields=fields)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Split the training dataset into training and validation:" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 68 + }, + "colab_type": "code", + "id": "WZ_4jiHVnMxN", + "outputId": "dfa51c04-4845-44c3-f50b-d36d41f132b8" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Num Train: 532000\n", + "Num Valid: 28000\n" + ] + } + ], + "source": [ + "train_data, valid_data = train_dataset.split(\n", + " split_ratio=[0.95, 0.05],\n", + " random_state=random.seed(RANDOM_SEED))\n", + "\n", + "print(f'Num Train: 
{len(train_data)}')\n", + "print(f'Num Valid: {len(valid_data)}')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "L-TBwKWPslPa" + }, + "source": [ + "Build the vocabulary based on the top \"VOCABULARY_SIZE\" words:" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 51 + }, + "colab_type": "code", + "id": "e8uNrjdtn4A8", + "outputId": "6cf499d7-7722-4da0-8576-ee0f218cc6e3" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Vocabulary size: 5002\n", + "Number of classes: 2\n" + ] + } + ], + "source": [ + "TEXT.build_vocab(train_data,\n", + " max_size=VOCABULARY_SIZE,\n", + " vectors='glove.6B.100d',\n", + " unk_init=torch.Tensor.normal_)\n", + "\n", + "LABEL.build_vocab(train_data)\n", + "\n", + "print(f'Vocabulary size: {len(TEXT.vocab)}')\n", + "print(f'Number of classes: {len(LABEL.vocab)}')" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['1', '0']" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "list(LABEL.vocab.freqs)[-10:]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "JpEMNInXtZsb" + }, + "source": [ + "The TEXT.vocab dictionary will contain the word counts and indices. The reason why the number of words is VOCABULARY_SIZE + 2 is that it contains two special tokens: `<unk>` for unknown words and `<pad>` for padding." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "eIQ_zfKLwjKm" + }, + "source": [ + "Make dataset iterators:" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "i7JiHR1stHNF" + }, + "outputs": [], + "source": [ + "train_loader, valid_loader, test_loader = data.BucketIterator.splits(\n", + " (train_data, valid_data, test_dataset), \n", + " batch_size=BATCH_SIZE,\n", + " sort_within_batch=True, # necessary for packed_padded_sequence\n", + " sort_key=lambda x: len(x.content),\n", + " device=DEVICE)" + ] + },
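Why `sort_within_batch=True` is marked as necessary: `pack_padded_sequence`, which the model below applies to every batch, expects the sequences in a batch ordered by descending length. A toy sketch with made-up token IDs and lengths, independent of the notebook's data:

```python
import torch
import torch.nn as nn

# Three toy sequences padded to length 5 and sorted by true length
# (descending) -- exactly what BucketIterator guarantees per batch.
batch = torch.tensor([[4, 7, 2, 9, 3],
                      [5, 1, 8, 0, 0],
                      [6, 2, 0, 0, 0]])   # 0 plays the role of <pad> here
lengths = torch.tensor([5, 3, 2])

emb = nn.Embedding(10, 4, padding_idx=0)
embedded = emb(batch.t())                 # shape: (seq_len=5, batch=3, emb=4)

# Packing drops the padded positions, so the LSTM never updates on padding.
packed = nn.utils.rnn.pack_padded_sequence(embedded, lengths)
out_packed, (h, c) = nn.LSTM(4, 8)(packed)
out, out_lens = nn.utils.rnn.pad_packed_sequence(out_packed)
print(out.shape, out_lens)                # torch.Size([5, 3, 8]) tensor([5, 3, 2])
```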
"print('\\nTest:')\n", + "for batch in test_loader:\n", + " print(f'Text matrix size: {batch.content[0].size()}')\n", + " print(f'Target vector size: {batch.classlabel.size()}')\n", + " break" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "G_grdW3pxCzz" + }, + "source": [ + "## Model" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "nQIUm5EjxFNa" + }, + "outputs": [], + "source": [ + "import torch.nn as nn\n", + "\n", + "\n", + "class RNN(nn.Module):\n", + " def __init__(self, input_dim, embedding_dim, bidirectional, hidden_dim, num_layers, output_dim, dropout, pad_idx):\n", + " \n", + " super().__init__()\n", + " \n", + " self.embedding = nn.Embedding(input_dim, embedding_dim, padding_idx=pad_idx)\n", + " self.rnn = nn.LSTM(embedding_dim, \n", + " hidden_dim,\n", + " num_layers=num_layers,\n", + " bidirectional=bidirectional, \n", + " dropout=dropout)\n", + " self.fc1 = nn.Linear(hidden_dim * num_layers, 64)\n", + " self.fc2 = nn.Linear(64, output_dim)\n", + " self.dropout = nn.Dropout(dropout)\n", + " \n", + " def forward(self, text, text_length):\n", + "\n", + " embedded = self.dropout(self.embedding(text))\n", + " packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, text_length)\n", + " packed_output, (hidden, cell) = self.rnn(packed_embedded)\n", + " output, output_lengths = nn.utils.rnn.pad_packed_sequence(packed_output)\n", + " hidden = self.dropout(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim=1))\n", + " hidden = self.fc1(hidden)\n", + " return hidden" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "Ik3NF3faxFmZ" + }, + "outputs": [], + "source": [ + "INPUT_DIM = len(TEXT.vocab)\n", + "\n", + "PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token]\n", + "\n", + "torch.manual_seed(RANDOM_SEED)\n", + "model = RNN(INPUT_DIM, EMBEDDING_DIM, BIDIRECTIONAL, HIDDEN_DIM, NUM_LAYERS, OUTPUT_DIM, DROPOUT, PAD_IDX)\n", + "model = model.to(DEVICE)\n", + "optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Lv9Ny9di6VcI" + }, + "source": [ + "## Training" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "T5t1Afn4xO11" + }, + "outputs": [], + "source": [ + "def compute_accuracy(model, data_loader, device):\n", + " model.eval()\n", + " correct_pred, num_examples = 0, 0\n", + " with torch.no_grad():\n", + " for batch_idx, batch_data in enumerate(data_loader):\n", + " text, text_lengths = batch_data.content\n", + " logits = model(text, text_lengths).squeeze(1)\n", + " _, predicted_labels = torch.max(logits, 1)\n", + " num_examples += batch_data.classlabel.size(0)\n", + " correct_pred += (predicted_labels.long() == batch_data.classlabel.long()).sum()\n", + " return correct_pred.float()/num_examples * 100" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1836 + }, + "colab_type": "code", + "id": "EABZM8Vo0ilB", + "outputId": "5d45e293-9909-4588-e793-8dfaf72e5c67" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch: 001/050 | Batch 000/4157 | Cost: 4.1925\n", + "Epoch: 001/050 | Batch 1000/4157 | Cost: 0.3392\n", + "Epoch: 001/050 | Batch 2000/4157 | Cost: 0.3254\n", + "Epoch: 001/050 | Batch 3000/4157 | 
Cost: 0.3263\n", + "Epoch: 001/050 | Batch 4000/4157 | Cost: 0.1488\n", + "training accuracy: 94.50%\n", + "valid accuracy: 94.12%\n", + "Time elapsed: 8.57 min\n", + "Epoch: 002/050 | Batch 000/4157 | Cost: 0.2246\n", + "Epoch: 002/050 | Batch 1000/4157 | Cost: 0.1248\n", + "Epoch: 002/050 | Batch 2000/4157 | Cost: 0.1107\n", + "Epoch: 002/050 | Batch 3000/4157 | Cost: 0.1820\n", + "Epoch: 002/050 | Batch 4000/4157 | Cost: 0.0808\n", + "training accuracy: 95.75%\n", + "valid accuracy: 95.35%\n", + "Time elapsed: 17.23 min\n", + "Epoch: 003/050 | Batch 000/4157 | Cost: 0.0877\n", + "Epoch: 003/050 | Batch 1000/4157 | Cost: 0.0720\n", + "Epoch: 003/050 | Batch 2000/4157 | Cost: 0.0770\n", + "Epoch: 003/050 | Batch 3000/4157 | Cost: 0.0876\n", + "Epoch: 003/050 | Batch 4000/4157 | Cost: 0.0851\n", + "training accuracy: 96.15%\n", + "valid accuracy: 95.62%\n", + "Time elapsed: 25.90 min\n", + "Epoch: 004/050 | Batch 000/4157 | Cost: 0.1596\n", + "Epoch: 004/050 | Batch 1000/4157 | Cost: 0.1571\n", + "Epoch: 004/050 | Batch 2000/4157 | Cost: 0.1728\n", + "Epoch: 004/050 | Batch 3000/4157 | Cost: 0.0911\n", + "Epoch: 004/050 | Batch 4000/4157 | Cost: 0.1380\n", + "training accuracy: 96.46%\n", + "valid accuracy: 95.86%\n", + "Time elapsed: 34.65 min\n", + "Epoch: 005/050 | Batch 000/4157 | Cost: 0.2183\n", + "Epoch: 005/050 | Batch 1000/4157 | Cost: 0.0951\n", + "Epoch: 005/050 | Batch 2000/4157 | Cost: 0.1052\n", + "Epoch: 005/050 | Batch 3000/4157 | Cost: 0.0759\n", + "Epoch: 005/050 | Batch 4000/4157 | Cost: 0.0705\n", + "training accuracy: 96.57%\n", + "valid accuracy: 95.69%\n", + "Time elapsed: 43.71 min\n", + "Epoch: 006/050 | Batch 000/4157 | Cost: 0.1320\n", + "Epoch: 006/050 | Batch 1000/4157 | Cost: 0.0989\n", + "Epoch: 006/050 | Batch 2000/4157 | Cost: 0.1763\n", + "Epoch: 006/050 | Batch 3000/4157 | Cost: 0.1935\n", + "Epoch: 006/050 | Batch 4000/4157 | Cost: 0.1201\n", + "training accuracy: 96.80%\n", + "valid accuracy: 95.91%\n", + "Time elapsed: 52.72 min\n", + "Epoch: 007/050 | Batch 000/4157 | Cost: 0.1282\n", + "Epoch: 007/050 | Batch 1000/4157 | Cost: 0.0945\n", + "Epoch: 007/050 | Batch 2000/4157 | Cost: 0.1035\n", + "Epoch: 007/050 | Batch 3000/4157 | Cost: 0.0490\n", + "Epoch: 007/050 | Batch 4000/4157 | Cost: 0.1134\n", + "training accuracy: 96.97%\n", + "valid accuracy: 96.08%\n", + "Time elapsed: 61.70 min\n", + "Epoch: 008/050 | Batch 000/4157 | Cost: 0.0646\n", + "Epoch: 008/050 | Batch 1000/4157 | Cost: 0.0576\n", + "Epoch: 008/050 | Batch 2000/4157 | Cost: 0.0668\n", + "Epoch: 008/050 | Batch 3000/4157 | Cost: 0.1527\n", + "Epoch: 008/050 | Batch 4000/4157 | Cost: 0.0996\n", + "training accuracy: 97.05%\n", + "valid accuracy: 96.15%\n", + "Time elapsed: 70.65 min\n", + "Epoch: 009/050 | Batch 000/4157 | Cost: 0.1095\n", + "Epoch: 009/050 | Batch 1000/4157 | Cost: 0.1356\n", + "Epoch: 009/050 | Batch 2000/4157 | Cost: 0.0523\n", + "Epoch: 009/050 | Batch 3000/4157 | Cost: 0.0761\n", + "Epoch: 009/050 | Batch 4000/4157 | Cost: 0.0700\n", + "training accuracy: 97.11%\n", + "valid accuracy: 96.09%\n", + "Time elapsed: 79.68 min\n", + "Epoch: 010/050 | Batch 000/4157 | Cost: 0.0975\n", + "Epoch: 010/050 | Batch 1000/4157 | Cost: 0.1032\n", + "Epoch: 010/050 | Batch 2000/4157 | Cost: 0.1357\n", + "Epoch: 010/050 | Batch 3000/4157 | Cost: 0.0950\n", + "Epoch: 010/050 | Batch 4000/4157 | Cost: 0.1263\n", + "training accuracy: 97.11%\n", + "valid accuracy: 96.06%\n", + "Time elapsed: 88.72 min\n", + "Epoch: 011/050 | Batch 000/4157 | Cost: 0.0440\n", + "Epoch: 011/050 | 
Batch 1000/4157 | Cost: 0.0980\n", + "Epoch: 011/050 | Batch 2000/4157 | Cost: 0.0603\n", + "Epoch: 011/050 | Batch 3000/4157 | Cost: 0.0524\n", + "Epoch: 011/050 | Batch 4000/4157 | Cost: 0.0840\n", + "training accuracy: 97.29%\n", + "valid accuracy: 96.21%\n", + "Time elapsed: 97.68 min\n", + "Epoch: 012/050 | Batch 000/4157 | Cost: 0.1569\n", + "Epoch: 012/050 | Batch 1000/4157 | Cost: 0.0744\n", + "Epoch: 012/050 | Batch 2000/4157 | Cost: 0.1388\n", + "Epoch: 012/050 | Batch 3000/4157 | Cost: 0.0720\n", + "Epoch: 012/050 | Batch 4000/4157 | Cost: 0.0588\n", + "training accuracy: 97.24%\n", + "valid accuracy: 96.15%\n", + "Time elapsed: 106.73 min\n", + "Epoch: 013/050 | Batch 000/4157 | Cost: 0.0353\n", + "Epoch: 013/050 | Batch 1000/4157 | Cost: 0.1184\n", + "Epoch: 013/050 | Batch 2000/4157 | Cost: 0.0866\n", + "Epoch: 013/050 | Batch 3000/4157 | Cost: 0.0525\n", + "Epoch: 013/050 | Batch 4000/4157 | Cost: 0.0722\n", + "training accuracy: 97.13%\n", + "valid accuracy: 95.86%\n", + "Time elapsed: 115.74 min\n", + "Epoch: 014/050 | Batch 000/4157 | Cost: 0.0898\n", + "Epoch: 014/050 | Batch 1000/4157 | Cost: 0.0936\n", + "Epoch: 014/050 | Batch 2000/4157 | Cost: 0.0786\n", + "Epoch: 014/050 | Batch 3000/4157 | Cost: 0.0615\n", + "Epoch: 014/050 | Batch 4000/4157 | Cost: 0.1044\n", + "training accuracy: 97.33%\n", + "valid accuracy: 96.11%\n", + "Time elapsed: 124.77 min\n", + "Epoch: 015/050 | Batch 000/4157 | Cost: 0.1224\n", + "Epoch: 015/050 | Batch 1000/4157 | Cost: 0.0771\n", + "Epoch: 015/050 | Batch 2000/4157 | Cost: 0.1181\n", + "Epoch: 015/050 | Batch 3000/4157 | Cost: 0.0447\n", + "Epoch: 015/050 | Batch 4000/4157 | Cost: 0.0996\n", + "training accuracy: 97.39%\n", + "valid accuracy: 96.10%\n", + "Time elapsed: 133.71 min\n", + "Epoch: 016/050 | Batch 000/4157 | Cost: 0.0977\n", + "Epoch: 016/050 | Batch 1000/4157 | Cost: 0.1531\n", + "Epoch: 016/050 | Batch 2000/4157 | Cost: 0.0744\n", + "Epoch: 016/050 | Batch 3000/4157 | Cost: 0.0793\n", + "Epoch: 016/050 | Batch 4000/4157 | Cost: 0.0540\n", + "training accuracy: 97.54%\n", + "valid accuracy: 96.31%\n", + "Time elapsed: 142.78 min\n", + "Epoch: 017/050 | Batch 000/4157 | Cost: 0.1054\n", + "Epoch: 017/050 | Batch 1000/4157 | Cost: 0.0698\n", + "Epoch: 017/050 | Batch 2000/4157 | Cost: 0.0439\n", + "Epoch: 017/050 | Batch 3000/4157 | Cost: 0.0602\n", + "Epoch: 017/050 | Batch 4000/4157 | Cost: 0.0843\n", + "training accuracy: 97.41%\n", + "valid accuracy: 96.08%\n", + "Time elapsed: 151.83 min\n", + "Epoch: 018/050 | Batch 000/4157 | Cost: 0.1025\n", + "Epoch: 018/050 | Batch 1000/4157 | Cost: 0.1091\n", + "Epoch: 018/050 | Batch 2000/4157 | Cost: 0.0359\n", + "Epoch: 018/050 | Batch 3000/4157 | Cost: 0.0509\n", + "Epoch: 018/050 | Batch 4000/4157 | Cost: 0.0674\n", + "training accuracy: 97.50%\n", + "valid accuracy: 96.15%\n", + "Time elapsed: 160.86 min\n", + "Epoch: 019/050 | Batch 000/4157 | Cost: 0.0795\n", + "Epoch: 019/050 | Batch 1000/4157 | Cost: 0.0561\n", + "Epoch: 019/050 | Batch 2000/4157 | Cost: 0.0533\n", + "Epoch: 019/050 | Batch 3000/4157 | Cost: 0.0801\n", + "Epoch: 019/050 | Batch 4000/4157 | Cost: 0.1394\n", + "training accuracy: 97.60%\n", + "valid accuracy: 96.19%\n", + "Time elapsed: 169.83 min\n", + "Epoch: 020/050 | Batch 000/4157 | Cost: 0.0896\n", + "Epoch: 020/050 | Batch 1000/4157 | Cost: 0.1357\n", + "Epoch: 020/050 | Batch 2000/4157 | Cost: 0.0574\n", + "Epoch: 020/050 | Batch 3000/4157 | Cost: 0.0695\n", + "Epoch: 020/050 | Batch 4000/4157 | Cost: 0.0781\n", + "training accuracy: 97.56%\n", + 
"valid accuracy: 96.16%\n", + "Time elapsed: 178.88 min\n", + "Epoch: 021/050 | Batch 000/4157 | Cost: 0.1040\n", + "Epoch: 021/050 | Batch 1000/4157 | Cost: 0.0993\n", + "Epoch: 021/050 | Batch 2000/4157 | Cost: 0.0427\n", + "Epoch: 021/050 | Batch 3000/4157 | Cost: 0.1151\n", + "Epoch: 021/050 | Batch 4000/4157 | Cost: 0.0666\n", + "training accuracy: 97.60%\n", + "valid accuracy: 96.14%\n", + "Time elapsed: 187.91 min\n", + "Epoch: 022/050 | Batch 000/4157 | Cost: 0.0760\n", + "Epoch: 022/050 | Batch 1000/4157 | Cost: 0.0557\n", + "Epoch: 022/050 | Batch 2000/4157 | Cost: 0.0538\n", + "Epoch: 022/050 | Batch 3000/4157 | Cost: 0.0619\n", + "Epoch: 022/050 | Batch 4000/4157 | Cost: 0.0884\n", + "training accuracy: 97.55%\n", + "valid accuracy: 96.16%\n", + "Time elapsed: 196.92 min\n", + "Epoch: 023/050 | Batch 000/4157 | Cost: 0.0938\n", + "Epoch: 023/050 | Batch 1000/4157 | Cost: 0.0543\n", + "Epoch: 023/050 | Batch 2000/4157 | Cost: 0.0295\n", + "Epoch: 023/050 | Batch 3000/4157 | Cost: 0.1257\n", + "Epoch: 023/050 | Batch 4000/4157 | Cost: 0.0690\n", + "training accuracy: 97.54%\n", + "valid accuracy: 96.19%\n", + "Time elapsed: 205.98 min\n", + "Epoch: 024/050 | Batch 000/4157 | Cost: 0.0709\n", + "Epoch: 024/050 | Batch 1000/4157 | Cost: 0.0676\n", + "Epoch: 024/050 | Batch 2000/4157 | Cost: 0.1822\n", + "Epoch: 024/050 | Batch 3000/4157 | Cost: 0.0687\n", + "Epoch: 024/050 | Batch 4000/4157 | Cost: 0.0737\n", + "training accuracy: 97.68%\n", + "valid accuracy: 96.28%\n", + "Time elapsed: 215.04 min\n", + "Epoch: 025/050 | Batch 000/4157 | Cost: 0.0740\n", + "Epoch: 025/050 | Batch 1000/4157 | Cost: 0.0932\n", + "Epoch: 025/050 | Batch 2000/4157 | Cost: 0.1179\n", + "Epoch: 025/050 | Batch 3000/4157 | Cost: 0.0735\n", + "Epoch: 025/050 | Batch 4000/4157 | Cost: 0.1019\n", + "training accuracy: 97.68%\n", + "valid accuracy: 96.25%\n", + "Time elapsed: 224.07 min\n", + "Epoch: 026/050 | Batch 000/4157 | Cost: 0.0893\n", + "Epoch: 026/050 | Batch 1000/4157 | Cost: 0.0890\n", + "Epoch: 026/050 | Batch 2000/4157 | Cost: 0.0736\n", + "Epoch: 026/050 | Batch 3000/4157 | Cost: 0.0675\n", + "Epoch: 026/050 | Batch 4000/4157 | Cost: 0.0344\n", + "training accuracy: 97.62%\n", + "valid accuracy: 96.23%\n", + "Time elapsed: 233.00 min\n", + "Epoch: 027/050 | Batch 000/4157 | Cost: 0.0331\n", + "Epoch: 027/050 | Batch 1000/4157 | Cost: 0.1079\n", + "Epoch: 027/050 | Batch 2000/4157 | Cost: 0.0800\n", + "Epoch: 027/050 | Batch 3000/4157 | Cost: 0.0703\n", + "Epoch: 027/050 | Batch 4000/4157 | Cost: 0.0759\n", + "training accuracy: 97.62%\n", + "valid accuracy: 96.11%\n", + "Time elapsed: 242.07 min\n", + "Epoch: 028/050 | Batch 000/4157 | Cost: 0.1071\n", + "Epoch: 028/050 | Batch 1000/4157 | Cost: 0.0826\n", + "Epoch: 028/050 | Batch 2000/4157 | Cost: 0.0699\n", + "Epoch: 028/050 | Batch 3000/4157 | Cost: 0.0783\n", + "Epoch: 028/050 | Batch 4000/4157 | Cost: 0.0550\n", + "training accuracy: 97.55%\n", + "valid accuracy: 96.09%\n", + "Time elapsed: 251.10 min\n", + "Epoch: 029/050 | Batch 000/4157 | Cost: 0.0291\n", + "Epoch: 029/050 | Batch 1000/4157 | Cost: 0.0881\n", + "Epoch: 029/050 | Batch 2000/4157 | Cost: 0.0537\n", + "Epoch: 029/050 | Batch 3000/4157 | Cost: 0.1502\n", + "Epoch: 029/050 | Batch 4000/4157 | Cost: 0.0614\n", + "training accuracy: 97.68%\n", + "valid accuracy: 96.20%\n", + "Time elapsed: 260.10 min\n", + "Epoch: 030/050 | Batch 000/4157 | Cost: 0.0922\n", + "Epoch: 030/050 | Batch 1000/4157 | Cost: 0.1103\n", + "Epoch: 030/050 | Batch 2000/4157 | Cost: 0.0814\n", + "Epoch: 
030/050 | Batch 3000/4157 | Cost: 0.0506\n", + "Epoch: 030/050 | Batch 4000/4157 | Cost: 0.1734\n", + "training accuracy: 97.69%\n", + "valid accuracy: 96.13%\n", + "Time elapsed: 269.04 min\n", + "Epoch: 031/050 | Batch 000/4157 | Cost: 0.1000\n", + "Epoch: 031/050 | Batch 1000/4157 | Cost: 0.0227\n", + "Epoch: 031/050 | Batch 2000/4157 | Cost: 0.1718\n", + "Epoch: 031/050 | Batch 3000/4157 | Cost: 0.0873\n", + "Epoch: 031/050 | Batch 4000/4157 | Cost: 0.0753\n", + "training accuracy: 97.67%\n", + "valid accuracy: 96.17%\n", + "Time elapsed: 278.07 min\n", + "Epoch: 032/050 | Batch 000/4157 | Cost: 0.0953\n", + "Epoch: 032/050 | Batch 1000/4157 | Cost: 0.0244\n", + "Epoch: 032/050 | Batch 2000/4157 | Cost: 0.0515\n", + "Epoch: 032/050 | Batch 3000/4157 | Cost: 0.0968\n", + "Epoch: 032/050 | Batch 4000/4157 | Cost: 0.0896\n", + "training accuracy: 97.67%\n", + "valid accuracy: 96.23%\n", + "Time elapsed: 287.10 min\n", + "Epoch: 033/050 | Batch 000/4157 | Cost: 0.0858\n", + "Epoch: 033/050 | Batch 1000/4157 | Cost: 0.0686\n", + "Epoch: 033/050 | Batch 2000/4157 | Cost: 0.0543\n", + "Epoch: 033/050 | Batch 3000/4157 | Cost: 0.0806\n", + "Epoch: 033/050 | Batch 4000/4157 | Cost: 0.0895\n", + "training accuracy: 97.66%\n", + "valid accuracy: 96.15%\n", + "Time elapsed: 296.08 min\n", + "Epoch: 034/050 | Batch 000/4157 | Cost: 0.0978\n", + "Epoch: 034/050 | Batch 1000/4157 | Cost: 0.1026\n", + "Epoch: 034/050 | Batch 2000/4157 | Cost: 0.0278\n", + "Epoch: 034/050 | Batch 3000/4157 | Cost: 0.0548\n", + "Epoch: 034/050 | Batch 4000/4157 | Cost: 0.1300\n", + "training accuracy: 97.66%\n", + "valid accuracy: 96.11%\n", + "Time elapsed: 305.03 min\n", + "Epoch: 035/050 | Batch 000/4157 | Cost: 0.0991\n", + "Epoch: 035/050 | Batch 1000/4157 | Cost: 0.0469\n", + "Epoch: 035/050 | Batch 2000/4157 | Cost: 0.0113\n", + "Epoch: 035/050 | Batch 3000/4157 | Cost: 0.0996\n", + "Epoch: 035/050 | Batch 4000/4157 | Cost: 0.1408\n", + "training accuracy: 97.69%\n", + "valid accuracy: 96.24%\n", + "Time elapsed: 314.03 min\n", + "Epoch: 036/050 | Batch 000/4157 | Cost: 0.0788\n", + "Epoch: 036/050 | Batch 1000/4157 | Cost: 0.0489\n", + "Epoch: 036/050 | Batch 2000/4157 | Cost: 0.1000\n", + "Epoch: 036/050 | Batch 3000/4157 | Cost: 0.0713\n", + "Epoch: 036/050 | Batch 4000/4157 | Cost: 0.0700\n", + "training accuracy: 97.70%\n", + "valid accuracy: 96.24%\n", + "Time elapsed: 323.07 min\n", + "Epoch: 037/050 | Batch 000/4157 | Cost: 0.0530\n", + "Epoch: 037/050 | Batch 1000/4157 | Cost: 0.1012\n", + "Epoch: 037/050 | Batch 2000/4157 | Cost: 0.0592\n", + "Epoch: 037/050 | Batch 3000/4157 | Cost: 0.1032\n", + "Epoch: 037/050 | Batch 4000/4157 | Cost: 0.0435\n", + "training accuracy: 97.64%\n", + "valid accuracy: 96.25%\n", + "Time elapsed: 332.01 min\n", + "Epoch: 038/050 | Batch 000/4157 | Cost: 0.0605\n", + "Epoch: 038/050 | Batch 1000/4157 | Cost: 0.1039\n", + "Epoch: 038/050 | Batch 2000/4157 | Cost: 0.0889\n", + "Epoch: 038/050 | Batch 3000/4157 | Cost: 0.0954\n", + "Epoch: 038/050 | Batch 4000/4157 | Cost: 0.0890\n", + "training accuracy: 97.69%\n", + "valid accuracy: 96.24%\n", + "Time elapsed: 341.00 min\n", + "Epoch: 039/050 | Batch 000/4157 | Cost: 0.0313\n", + "Epoch: 039/050 | Batch 1000/4157 | Cost: 0.1955\n", + "Epoch: 039/050 | Batch 2000/4157 | Cost: 0.1388\n", + "Epoch: 039/050 | Batch 3000/4157 | Cost: 0.0850\n", + "Epoch: 039/050 | Batch 4000/4157 | Cost: 0.0574\n", + "training accuracy: 97.71%\n", + "valid accuracy: 96.23%\n", + "Time elapsed: 350.03 min\n", + "Epoch: 040/050 | Batch 000/4157 | 
Cost: 0.0289\n", + "Epoch: 040/050 | Batch 1000/4157 | Cost: 0.0602\n", + "Epoch: 040/050 | Batch 2000/4157 | Cost: 0.0735\n", + "Epoch: 040/050 | Batch 3000/4157 | Cost: 0.0592\n", + "Epoch: 040/050 | Batch 4000/4157 | Cost: 0.0692\n", + "training accuracy: 97.62%\n", + "valid accuracy: 96.17%\n", + "Time elapsed: 359.08 min\n", + "Epoch: 041/050 | Batch 000/4157 | Cost: 0.0815\n", + "Epoch: 041/050 | Batch 1000/4157 | Cost: 0.0868\n", + "Epoch: 041/050 | Batch 2000/4157 | Cost: 0.0714\n", + "Epoch: 041/050 | Batch 3000/4157 | Cost: 0.1631\n", + "Epoch: 041/050 | Batch 4000/4157 | Cost: 0.0758\n", + "training accuracy: 97.72%\n", + "valid accuracy: 96.29%\n", + "Time elapsed: 367.99 min\n", + "Epoch: 042/050 | Batch 000/4157 | Cost: 0.0591\n", + "Epoch: 042/050 | Batch 1000/4157 | Cost: 0.0564\n", + "Epoch: 042/050 | Batch 2000/4157 | Cost: 0.0635\n", + "Epoch: 042/050 | Batch 3000/4157 | Cost: 0.1051\n", + "Epoch: 042/050 | Batch 4000/4157 | Cost: 0.0734\n", + "training accuracy: 97.64%\n", + "valid accuracy: 96.14%\n", + "Time elapsed: 377.04 min\n", + "Epoch: 043/050 | Batch 000/4157 | Cost: 0.0693\n", + "Epoch: 043/050 | Batch 1000/4157 | Cost: 0.0590\n", + "Epoch: 043/050 | Batch 2000/4157 | Cost: 0.0638\n", + "Epoch: 043/050 | Batch 3000/4157 | Cost: 0.0658\n", + "Epoch: 043/050 | Batch 4000/4157 | Cost: 0.0599\n", + "training accuracy: 97.76%\n", + "valid accuracy: 96.38%\n", + "Time elapsed: 386.09 min\n", + "Epoch: 044/050 | Batch 000/4157 | Cost: 0.0503\n", + "Epoch: 044/050 | Batch 1000/4157 | Cost: 0.1081\n", + "Epoch: 044/050 | Batch 2000/4157 | Cost: 0.0783\n", + "Epoch: 044/050 | Batch 3000/4157 | Cost: 0.0634\n", + "Epoch: 044/050 | Batch 4000/4157 | Cost: 0.1016\n", + "training accuracy: 97.62%\n", + "valid accuracy: 96.20%\n", + "Time elapsed: 395.10 min\n", + "Epoch: 045/050 | Batch 000/4157 | Cost: 0.0675\n", + "Epoch: 045/050 | Batch 1000/4157 | Cost: 0.1789\n", + "Epoch: 045/050 | Batch 2000/4157 | Cost: 0.0497\n", + "Epoch: 045/050 | Batch 3000/4157 | Cost: 0.0718\n", + "Epoch: 045/050 | Batch 4000/4157 | Cost: 0.1590\n", + "training accuracy: 97.68%\n", + "valid accuracy: 96.25%\n", + "Time elapsed: 404.06 min\n", + "Epoch: 046/050 | Batch 000/4157 | Cost: 0.1274\n", + "Epoch: 046/050 | Batch 1000/4157 | Cost: 0.1153\n", + "Epoch: 046/050 | Batch 2000/4157 | Cost: 0.1211\n", + "Epoch: 046/050 | Batch 3000/4157 | Cost: 0.0819\n", + "Epoch: 046/050 | Batch 4000/4157 | Cost: 0.1036\n", + "training accuracy: 97.73%\n", + "valid accuracy: 96.24%\n", + "Time elapsed: 413.10 min\n", + "Epoch: 047/050 | Batch 000/4157 | Cost: 0.1166\n", + "Epoch: 047/050 | Batch 1000/4157 | Cost: 0.0465\n", + "Epoch: 047/050 | Batch 2000/4157 | Cost: 0.1046\n", + "Epoch: 047/050 | Batch 3000/4157 | Cost: 0.0449\n", + "Epoch: 047/050 | Batch 4000/4157 | Cost: 0.1335\n", + "training accuracy: 97.68%\n", + "valid accuracy: 96.31%\n", + "Time elapsed: 422.12 min\n", + "Epoch: 048/050 | Batch 000/4157 | Cost: 0.0980\n", + "Epoch: 048/050 | Batch 1000/4157 | Cost: 0.0845\n", + "Epoch: 048/050 | Batch 2000/4157 | Cost: 0.0559\n", + "Epoch: 048/050 | Batch 3000/4157 | Cost: 0.0261\n", + "Epoch: 048/050 | Batch 4000/4157 | Cost: 0.0484\n", + "training accuracy: 97.69%\n", + "valid accuracy: 96.28%\n", + "Time elapsed: 431.10 min\n", + "Epoch: 049/050 | Batch 000/4157 | Cost: 0.0621\n", + "Epoch: 049/050 | Batch 1000/4157 | Cost: 0.0815\n", + "Epoch: 049/050 | Batch 2000/4157 | Cost: 0.0569\n", + "Epoch: 049/050 | Batch 3000/4157 | Cost: 0.1636\n", + "Epoch: 049/050 | Batch 4000/4157 | Cost: 
0.0797\n", + "training accuracy: 97.58%\n", + "valid accuracy: 96.11%\n", + "Time elapsed: 440.10 min\n", + "Epoch: 050/050 | Batch 000/4157 | Cost: 0.0517\n", + "Epoch: 050/050 | Batch 1000/4157 | Cost: 0.0388\n", + "Epoch: 050/050 | Batch 2000/4157 | Cost: 0.0833\n", + "Epoch: 050/050 | Batch 3000/4157 | Cost: 0.1234\n", + "Epoch: 050/050 | Batch 4000/4157 | Cost: 0.0752\n", + "training accuracy: 97.67%\n", + "valid accuracy: 96.31%\n", + "Time elapsed: 449.14 min\n", + "Total Training Time: 449.14 min\n", + "Test accuracy: 96.48%\n" + ] + } + ], + "source": [ + "start_time = time.time()\n", + "\n", + "for epoch in range(NUM_EPOCHS):\n", + " model.train()\n", + " for batch_idx, batch_data in enumerate(train_loader):\n", + " \n", + " text, text_lengths = batch_data.content\n", + " \n", + " ### FORWARD AND BACK PROP\n", + " logits = model(text, text_lengths).squeeze(1)\n", + " cost = F.cross_entropy(logits, batch_data.classlabel.long())\n", + " optimizer.zero_grad()\n", + " \n", + " cost.backward()\n", + " \n", + " ### UPDATE MODEL PARAMETERS\n", + " optimizer.step()\n", + " \n", + " ### LOGGING\n", + " if not batch_idx % 1000:\n", + " print (f'Epoch: {epoch+1:03d}/{NUM_EPOCHS:03d} | '\n", + " f'Batch {batch_idx:03d}/{len(train_loader):03d} | '\n", + " f'Cost: {cost:.4f}')\n", + "\n", + " with torch.set_grad_enabled(False):\n", + " print(f'training accuracy: '\n", + " f'{compute_accuracy(model, train_loader, DEVICE):.2f}%'\n", + " f'\\nvalid accuracy: '\n", + " f'{compute_accuracy(model, valid_loader, DEVICE):.2f}%')\n", + " \n", + " print(f'Time elapsed: {(time.time() - start_time)/60:.2f} min')\n", + " \n", + "print(f'Total Training Time: {(time.time() - start_time)/60:.2f} min')\n", + "print(f'Test accuracy: {compute_accuracy(model, test_loader, DEVICE):.2f}%')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Evaluation" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Evaluating on some new text that has been collected from recent Yelp reviews and are not part of the training or test sets." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "jt55pscgFdKZ" + }, + "outputs": [], + "source": [ + "import spacy\n", + "nlp = spacy.load('en')\n", + "\n", + "\n", + "map_dictionary = {\n", + " 0: \"negative\",\n", + " 1: \"positive\"\n", + "}\n", + "\n", + "\n", + "def predict_class(model, sentence, min_len=4):\n", + " # Somewhat based on\n", + " # https://github.com/bentrevett/pytorch-sentiment-analysis/\n", + " # blob/master/5%20-%20Multi-class%20Sentiment%20Analysis.ipynb\n", + " model.eval()\n", + " tokenized = [tok.text for tok in nlp.tokenizer(sentence)]\n", + " if len(tokenized) < min_len:\n", + " # pad very short inputs so the packed sequence has a valid length\n", + " tokenized += ['<pad>'] * (min_len - len(tokenized))\n", + " indexed = [TEXT.vocab.stoi[t] for t in tokenized]\n", + " length = [len(indexed)]\n", + " tensor = torch.LongTensor(indexed).to(DEVICE)\n", + " tensor = tensor.unsqueeze(1)\n", + " length_tensor = torch.LongTensor(length)\n", + " preds = model(tensor, length_tensor)\n", + " preds = torch.softmax(preds, dim=1)\n", + "\n", + " proba, class_label = preds.max(dim=1)\n", + " return proba.item(), class_label.item()" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Class Label: 1 -> positive\n", + "Probability: 0.9960760474205017\n" + ] + } + ], + "source": [ + "text = \"\"\"\n", + "I have returned many times since my original review, and I can attest to the fact that, indeed, \n", + "the plethora of books she provides does not disappoint. Although under new ownership, \n", + "the vibe and the focus remains unchanged. \n", + "\n", + "I still collect Kobayashi poetry anytime I stumble upon it.\n", + "\n", + "My absolute favorite bookshop, card vendor, and truth teller. \n", + "\n", + "Until next time.\n", + "\"\"\"\n", + "\n", + "proba, pred_label = predict_class(model, text)\n", + "\n", + "print(f'Class Label: {pred_label} -> {map_dictionary[pred_label]}')\n", + "print(f'Probability: {proba}')" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Class Label: 0 -> negative\n", + "Probability: 0.999991774559021\n" + ] + } + ], + "source": [ + "text = \"\"\"\n", + "Horrible customer service experience!!\n", + "\n", + "Why I even bothered to go here is beyond me.. \n", + "My wife asked me to get some gift cards and my dad \n", + "mentioned that he would give me a yearly membership as a present. \n", + "I made the mistake of not listening to that little voice in my head \n", + "screaming \"DON'T!!!!\". I got the gift cards and asked for the membership \n", + "and then realized that they hadn't given me the membership. So I go in the \n", + "next day and asked someone in customer service if I could get the membership \n", + "and then have them apply the discount to the previous purchases and some new \n", + "purchases and their response was \"Of course.. Talk to Scott, our head cashier, \n", + "and he will gladly take care of this\". I go to Scott and he tells me \"I've never \n", + "done that, we would never do that and whoever told you that was obviously \n", + "wrong\" Needless to say, I did not make any new purchases and I will promptly \n", + "return any of the previous purchases and give my hard-earned money to someone who deserves it.\n", + "\n", + "Bottom line.. 
Overpriced lousy customer service is not for me. In this day\n", + "and age they should know better than that and you should use your buying power to show them. Stay away..\n", + "\"\"\"\n", + "\n", + "proba, pred_label = predict_class(model, text)\n", + "\n", + "print(f'Class Label: {pred_label} -> {map_dictionary[pred_label]}')\n", + "print(f'Probability: {proba}')" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "7lRusB3dF80X" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "pandas 0.24.2\n", + "torch 1.3.0\n", + "numpy 1.17.2\n", + "spacy 2.2.3\n", + "torchtext 0.4.0\n", + "\n" + ] + } + ], + "source": [ + "%watermark -iv" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [], + "source": [ + "torch.save(model.state_dict(), 'rnn_bi_multilayer_lstm_own_csv_yelp-polarity.pt')" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "name": "rnn_lstm_packed_imdb.ipynb", + "provenance": [], + "version": "0.3.2" + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/pytorch_ipynb/rnn/rnn_lstm_packed_own_csv_imdb.ipynb b/pytorch_ipynb/rnn/rnn_lstm_packed_own_csv_imdb.ipynb index 2402f62..7d681bd 100644 --- a/pytorch_ipynb/rnn/rnn_lstm_packed_own_csv_imdb.ipynb +++ b/pytorch_ipynb/rnn/rnn_lstm_packed_own_csv_imdb.ipynb @@ -119,28 +119,28 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "--2019-04-28 01:19:02-- https://github.com/rasbt/python-machine-learning-book-2nd-edition/raw/master/code/ch08/movie_data.csv.gz\n", - "Resolving github.com (github.com)... 192.30.253.113, 192.30.253.112\n", - "Connecting to github.com (github.com)|192.30.253.113|:443... connected.\n", + "--2019-11-28 19:47:46-- https://github.com/rasbt/python-machine-learning-book-2nd-edition/raw/master/code/ch08/movie_data.csv.gz\n", + "Resolving github.com (github.com)... 140.82.113.3\n", + "Connecting to github.com (github.com)|140.82.113.3|:443... connected.\n", "HTTP request sent, awaiting response... 302 Found\n", "Location: https://raw.githubusercontent.com/rasbt/python-machine-learning-book-2nd-edition/master/code/ch08/movie_data.csv.gz [following]\n", - "--2019-04-28 01:19:02-- https://raw.githubusercontent.com/rasbt/python-machine-learning-book-2nd-edition/master/code/ch08/movie_data.csv.gz\n", + "--2019-11-28 19:47:46-- https://raw.githubusercontent.com/rasbt/python-machine-learning-book-2nd-edition/master/code/ch08/movie_data.csv.gz\n", "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.184.133\n", "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.184.133|:443... connected.\n", "HTTP request sent, awaiting response... 
200 OK\n", "Length: 26521894 (25M) [application/octet-stream]\n", "Saving to: ‘movie_data.csv.gz’\n", "\n", - "movie_data.csv.gz 100%[===================>] 25.29M 57.1MB/s in 0.4s \n", + "movie_data.csv.gz 100%[===================>] 25.29M 10.5MB/s in 2.4s \n", "\n", - "2019-04-28 01:19:03 (57.1 MB/s) - ‘movie_data.csv.gz’ saved [26521894/26521894]\n", + "2019-11-28 19:47:49 (10.5 MB/s) - ‘movie_data.csv.gz’ saved [26521894/26521894]\n", "\n" ] } @@ -151,7 +151,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -925,9 +925,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.7.1" + "version": "3.7.3" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 }