From 5873e4a884710d9254cf9f2fbdc66395e87c8b9d Mon Sep 17 00:00:00 2001 From: 10-zin <33179372+10-zin@users.noreply.github.com> Date: Mon, 15 Oct 2018 16:15:20 +0530 Subject: [PATCH] Cell not needed in RNN --- tutorials/02-intermediate/recurrent_neural_network/main.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tutorials/02-intermediate/recurrent_neural_network/main.py b/tutorials/02-intermediate/recurrent_neural_network/main.py index 9b8685ca..1f21d6f2 100644 --- a/tutorials/02-intermediate/recurrent_neural_network/main.py +++ b/tutorials/02-intermediate/recurrent_neural_network/main.py @@ -48,7 +48,6 @@ def __init__(self, input_size, hidden_size, num_layers, num_classes): def forward(self, x): # Set initial hidden and cell states h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device) - c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device) # Forward propagate LSTM out, _ = self.lstm(x, (h0, c0)) # out: tensor of shape (batch_size, seq_length, hidden_size) @@ -99,4 +98,4 @@ def forward(self, x): print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total)) # Save the model checkpoint -torch.save(model.state_dict(), 'model.ckpt') \ No newline at end of file +torch.save(model.state_dict(), 'model.ckpt')