
Add literature background to char-rnn #351

Merged
24 changes: 24 additions & 0 deletions text/char-rnn/README.md
@@ -0,0 +1,24 @@
# Character-level RNN

![char-rnn](../char-rnn/docs/rnn-train.png)

[Source](https://d2l.ai/chapter_recurrent-neural-networks/rnn.html#rnn-based-character-level-language-models)

## Model information

A recurrent neural network (RNN) outputs a prediction and a hidden state at each step of the computation. The hidden state captures the historical information of a sequence (i.e., the network has memory), and the prediction is the model's output for the current step. We use this type of neural network to model sequences such as text or time series.
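
For instance, here is a minimal sketch (not part of this example's code) of how a recurrent layer in Flux carries its hidden state between calls:

```julia
using Flux

rnn = Flux.RNN(4, 3)   # 4-dimensional input, 3-dimensional hidden state

x = rand(Float32, 4)   # one input step
y1 = rnn(x)            # prediction after the first step
y2 = rnn(x)            # same input, different output: the hidden state changed
Flux.reset!(rnn)       # reset the hidden state to its initial value
```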


## Training

```shell
cd text/char-rnn
julia --project char-rnn.jl
```

## References

* [The Unreasonable Effectiveness of Recurrent Neural Networks](http://karpathy.github.io/2015/05/21/rnn-effectiveness/)
* [Understanding LSTM Networks](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
* [Aston Zhang, Zachary C. Lipton, Mu Li and Alexander J. Smola, "Dive into Deep Learning", 2020](https://d2l.ai/chapter_recurrent-neural-networks/rnn.html#rnn-based-character-level-language-models)

94 changes: 85 additions & 9 deletions text/char-rnn/char-rnn.jl
@@ -1,10 +1,38 @@
# # Character-level RNN

# In this example, we create a character-level recurrent neural network.
# A recurrent neural network (RNN) outputs a prediction and a hidden state at each step
# of the computation. The hidden state captures the historical information of a sequence
# (i.e., the network has memory), and the prediction is the model's output for the current step.
# We use this type of neural network to model sequences such as text or time series.


# ![char-rnn](../char-rnn/docs/rnn-train.png)

# Source: https://d2l.ai/chapter_recurrent-neural-networks/rnn.html#rnn-based-character-level-language-models

# This example demonstrates the use of Flux's implementation of the
# [Long Short-Term Memory (LSTM)](https://www.researchgate.net/publication/13853244_Long_Short-term_Memory) recurrent layer,
# which is an RNN variant that generally exhibits a longer memory span over sequences, as well as
# [Flux utility functions](https://fluxml.ai/Flux.jl/stable/utilities/).

# If you need more information about how RNNs work and related technical concepts,
# check out the following resources:

# * [The Unreasonable Effectiveness of Recurrent Neural Networks](http://karpathy.github.io/2015/05/21/rnn-effectiveness/)
# * [Understanding LSTM Networks](https://colah.github.io/posts/2015-08-Understanding-LSTMs/)
# * [Illustrated Guide to Recurrent Neural Networks: Understanding the Intuition](https://www.youtube.com/watch?v=LHXXI4-IEns)

# To run this example, we need the following packages:

using Flux
using Flux: onehot, chunk, batchseq, throttle, logitcrossentropy
using StatsBase: wsample
using Base.Iterators: partition
using Parameters: @with_kw

# Hyperparameter arguments
# We set default values for the hyperparameters:

@with_kw mutable struct Args
lr::Float64 = 1e-2 # Learning rate
seqlen::Int = 50 # Length of batch sequences
@@ -13,44 +41,70 @@ using Parameters: @with_kw
epochs::Int = 2 # Number of Epochs
end
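
# `@with_kw` from Parameters.jl gives `Args` a keyword constructor, so any of
# the defaults above can be overridden, for example (illustration only; this
# instance is not used below):

demo_args = Args(lr = 1e-3, epochs = 5)   ## all other fields keep their defaults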

# ## Data

# We create the function `getdata` to download the training data and create arrays of batches
# for training the model:


function getdata(args)
# Download the data to 'input.txt' if it has not been downloaded yet
## Download the data to 'input.txt' if it has not been downloaded yet
isfile("input.txt") ||
download("https://cs.stanford.edu/people/karpathy/char-rnn/shakespeare_input.txt","input.txt")

text = collect(String(read("input.txt")))

# an array of all unique characters
## an array of all unique characters
alphabet = [unique(text)..., '_']

text = map(ch -> onehot(ch, alphabet), text)
stop = onehot('_', alphabet)

N = length(alphabet)

# Partition the data into a sequence of batches, which are then collected into an array of batches
## Partition the data into a sequence of batches, which are then collected into an array of batches
Xs = collect(partition(batchseq(chunk(text, args.nbatch), stop), args.seqlen))
Ys = collect(partition(batchseq(chunk(text[2:end], args.nbatch), stop), args.seqlen))

return Xs, Ys, N, alphabet
end

# Function to construct model
# The function `getdata` performs the following tasks:

# * Downloads a dataset of [all of Shakespeare's works (concatenated)](https://cs.stanford.edu/people/karpathy/char-rnn/shakespeare_input.txt)
# if it has not been downloaded previously. It loads the data as a vector of characters with `collect`.
# * Gets the alphabet. It consists of the unique characters of the data plus the stop character '_'.
# * One-hot encodes the characters of the text, and the stop character, using the alphabet.
# * Gets the size `N` of the alphabet.
# * Partitions the data into an array of batches. Note that the `Xs` array contains the sequences of characters in the text whereas the `Ys` array contains, for each position, the next character of the sequence, i.e. the prediction target (see the sketch below).
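
# To make the batching pipeline concrete, here is a minimal sketch on a toy
# string (illustration only; the `toy_*` names are ours and are not used by
# the rest of the example):

toy = collect("abab")
toy_alphabet = [unique(toy)..., '_']
toy_onehot = map(ch -> onehot(ch, toy_alphabet), toy)
## Split into 2 parallel sequences, batch them (padding with '_'),
## then cut the result into mini-sequences of length 2:
toy_Xs = collect(partition(batchseq(chunk(toy_onehot, 2), onehot('_', toy_alphabet)), 2))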

# ## Model

# We create the RNN with two of Flux's LSTM layers and an output layer whose size equals the size of the alphabet:

function build_model(N)
return Chain(
LSTM(N, 128),
LSTM(128, 128),
Dense(128, N))
end

# The size of the input and output layers is the same as the size of the alphabet.
# Also, we set the size of the hidden layers to 128.
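
# As a quick sanity check, the model maps a one-hot character to a vector of
# logits, one per alphabet symbol, and its LSTM state persists across calls
# until we reset it (illustration only; the `demo_*` names are ours):

demo_alphabet = ['a', 'b', '_']
demo_m = build_model(length(demo_alphabet))
demo_y = demo_m(onehot('a', demo_alphabet))   ## 3-element vector of logits
Flux.reset!(demo_m)                           ## clear the hidden state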

# ## Train the model

# Now, we define the function `train` that creates the model and the loss function as well as the training loop:


function train(; kws...)
# Initialize the parameters
## Initialize the parameters
args = Args(; kws...)

# Get Data
## Get Data
Xs, Ys, N, alphabet = getdata(args)

# Constructing Model
## Constructing Model
m = build_model(N)

function loss(xs, ys)
@@ -71,7 +125,27 @@ function train(; kws...)
return m, alphabet
end

# Sampling
# The function `train` performs the following tasks:

# * Calls the function `getdata` to obtain the training data (the inputs `Xs` and their target sequences `Ys`) as well as the alphabet and its size.
# * Calls the function `build_model` to create the RNN.
# * Defines the loss function. For this type of neural network, we use the [logitcrossentropy](https://fluxml.ai/Flux.jl/stable/models/losses/#Flux.Losses.logitcrossentropy)
# loss function. Note that it is important to call the function [reset!](https://fluxml.ai/Flux.jl/stable/models/layers/#Flux.reset!)
# before computing the loss so that the hidden state of the recurrent layers is set back to its initial value
# (a sketch of such a loss appears after this list).
# * Sets the [ADAM optimiser](https://fluxml.ai/Flux.jl/stable/training/optimisers/#Flux.Optimise.ADAM) with the learning rate *lr* we defined above.
# * Creates a [callback](https://fluxml.ai/Flux.jl/stable/training/training/#Callbacks) *evalcb* so that you can observe the training process (it prints the loss value).
# * Runs the training loop using [Flux's train!](https://fluxml.ai/Flux.jl/stable/training/training/#Flux.Optimise.train!).
# It uses the function [throttle](https://fluxml.ai/Flux.jl/stable/utilities/#Flux.throttle) so that the callback *evalcb*
# is triggered at most once per timeout period, in seconds (as defined above).
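
# The body of `loss` is collapsed in the diff above; in sketch form, a loss for
# this setup looks like the following (our reconstruction, not necessarily the
# exact merged code):

## function loss(xs, ys)
##     Flux.reset!(m)                               ## start each batch from a fresh hidden state
##     return sum(logitcrossentropy.(m.(xs), ys))   ## sum the loss over the sequence steps
## end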

# ## Test the model

# We define the function `sample_data` to test the model.
# It generates samples of text using the alphabet that the function `getdata` computed.
# Notice that it obtains the model's prediction by calling the
# [softmax function](https://fluxml.ai/Flux.jl/stable/models/nnlib/#Softmax)
# to turn the output into a probability distribution, and then it samples the next
# character at random from that distribution (see the sketch after the function below).

function sample_data(m, alphabet, len; seed="")
m = cpu(m)
Flux.reset!(m)
@@ -88,6 +162,8 @@ function sample_data(m, alphabet, len; seed="")
return String(take!(buf))
end
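
# The middle of `sample_data` is collapsed above; its core sampling step is,
# in sketch form (our reconstruction, with `c` the current character and `buf`
# an `IOBuffer` collecting the output):

## c = wsample(alphabet, softmax(m(onehot(c, alphabet))))   ## draw the next character
## write(buf, c)                                            ## append it to the sample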

# Finally, to run this example we call the functions `train` and `sample_data`:

cd(@__DIR__)
m, alphabet = train()
sample_data(m, alphabet, 1000) |> println
Binary file added text/char-rnn/docs/rnn-train.png