Train status web interface #141

Open · wants to merge 29 commits into master

Commits (29)
4caca25
Start work on visualization utils
EricZeiberg Aug 10, 2015
914679a
Remove unnessesary css file
EricZeiberg Aug 10, 2015
40c1821
Add graph
EricZeiberg Aug 11, 2015
3ce252a
Add overall completion bar
EricZeiberg Aug 11, 2015
760b1d8
Remove test data
EricZeiberg Aug 11, 2015
7a39279
Add documentation to readme
EricZeiberg Aug 11, 2015
5207253
Update Readme.md
EricZeiberg Aug 11, 2015
e4edbed
Fix a couple of bugs
EricZeiberg Aug 11, 2015
251b040
Merge branch 'master' of https://github.com/EricZeiberg/char-rnn
EricZeiberg Aug 11, 2015
c1d4e6b
Definition before implementation...
EricZeiberg Aug 11, 2015
cd07f17
Round values
EricZeiberg Aug 11, 2015
e932d73
Fix a couple more bugs
EricZeiberg Aug 11, 2015
4b87f27
Remove epoch rounding since its already at 3 decimal places
EricZeiberg Aug 11, 2015
636fe50
Truncate epoch to 3 decimal places
EricZeiberg Aug 11, 2015
4a00e56
Update .gitignore
EricZeiberg Aug 12, 2015
18118e3
Remove txt files
EricZeiberg Aug 12, 2015
e2226d3
Explain truncate
EricZeiberg Aug 12, 2015
a144a22
Remove scale lines
EricZeiberg Aug 12, 2015
d37bcfd
Format HTML
EricZeiberg Aug 12, 2015
6e62bcc
Add exact training loss value text
EricZeiberg Aug 12, 2015
6afae42
Fix issue with circle bar going backwards
EricZeiberg Aug 12, 2015
a842b15
Apparently this can happen too
EricZeiberg Aug 12, 2015
dc9bee4
Add some graph optimization
EricZeiberg Aug 12, 2015
0b73252
More optimizations
EricZeiberg Aug 12, 2015
f062698
Change graphing library to CanvasJS (wow)
EricZeiberg Aug 12, 2015
4396185
Put JS code in seperate file
EricZeiberg Aug 12, 2015
5edfaeb
Final touches
EricZeiberg Aug 12, 2015
62d5b24
merge web visualizer from EricZeiberg/char-rnn
whackashoe Dec 15, 2015
e1cbb11
implement web visualization using websockets
whackashoe Dec 15, 2015

10 changes: 10 additions & 0 deletions Readme.md
@@ -39,6 +39,13 @@ $ luarocks install cltorch
$ luarocks install clnn
```

If you'd like to run the web visualization, you will need to install the `lua-websockets` and `luajson` packages; you can then use the option `-visualize [port]` during training. Port 8080 is recommended.

```bash
$ luarocks install lua-websockets
$ luarocks install luajson
```

## Usage

### Data
@@ -65,6 +72,9 @@ $ th train.lua -data_dir data/some_folder -rnn_size 512 -num_layers 2 -dropout 0

**Checkpoints.** While the model is training it will periodically write checkpoint files to the `cv` folder. The frequency with which these checkpoints are written is controlled with number of iterations, as specified with the `eval_val_every` option (e.g. if this is 1 then a checkpoint is written every iteration). The filename of these checkpoints contains a very important number: the **loss**. For example, a checkpoint with filename `lm_lstm_epoch0.95_2.0681.t7` indicates that at this point the model was on epoch 0.95 (i.e. it has almost done one full pass over the training data), and the loss on validation data was 2.0681. This number is very important because the lower it is, the better the checkpoint works. Once you start to generate data (discussed below), you will want to use the model checkpoint that reports the lowest validation loss. Notice that this might not necessarily be the last checkpoint at the end of training (due to possible overfitting).
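
As an aside (not part of this pull request), here is a minimal sketch of selecting the lowest-loss checkpoint by parsing these filenames; it assumes Torch's `paths` package and the default `cv` folder:

```lua
-- Minimal sketch: scan the cv folder and keep the checkpoint whose filename
-- reports the lowest validation loss (e.g. lm_lstm_epoch0.95_2.0681.t7).
local paths = require 'paths'

local best_file, best_loss
for file in paths.files('cv') do
  local loss = tonumber(file:match('_(%d+%.%d+)%.t7$'))
  if loss and (best_loss == nil or loss < best_loss) then
    best_file, best_loss = file, loss
  end
end
print('best checkpoint:', best_file, 'validation loss:', best_loss)
```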

If you would like to visualize the training, add the `-visualize` flag. This enables reporting of various statistics to a webpage, which you can then open in your browser via any simple HTTP server. The recommended one is Python's `SimpleHTTPServer`, which you can start by installing Python and running `python -m SimpleHTTPServer <port>`.


Other important quantities to be aware of are `batch_size` (call it B), `seq_length` (call it S), and the `train_frac` and `val_frac` settings. The batch size specifies how many streams of data are processed in parallel at one time. The sequence length specifies the length of each stream, which is also the limit at which the gradients can propagate backwards in time. For example, if `seq_length` is 20, then the gradient signal will never backpropagate more than 20 time steps, and the model might not *find* dependencies longer than this length in number of characters. Thus, if you have a very difficult dataset where there are a lot of long-term dependencies you will want to increase this setting. Now, if at runtime your input text file has N characters, these first all get split into chunks of size `BxS`. These chunks then get allocated across three splits: train/val/test according to the `frac` settings. By default `train_frac` is 0.95 and `val_frac` is 0.05, which means that 95% of our data chunks will be trained on and 5% of the chunks will be used to estimate the validation loss (and hence the generalization). If your data is small, it's possible that with the default settings you'll only have very few chunks in total (for example 100). This is bad: in these cases you may want to decrease batch size or sequence length.
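
For concreteness (again, not part of this pull request), a small sketch of how those chunk counts fall out of these settings; the numbers are made up and the real splitting lives in `CharSplitLMMinibatchLoader`:

```lua
-- Hypothetical numbers: N characters in input.txt, split into BxS chunks,
-- then divided across train/val according to the frac settings.
local N = 1000000            -- characters in the input file
local B, S = 50, 50          -- batch_size and seq_length
local train_frac, val_frac = 0.95, 0.05

local nchunks = math.floor(N / (B * S))
local ntrain  = math.floor(nchunks * train_frac)
local nval    = math.floor(nchunks * val_frac)
print(string.format('%d chunks total: %d train, %d val', nchunks, ntrain, nval))
```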

Note that you can also initialize parameters from a previously saved checkpoint using `init_from`.
95 changes: 92 additions & 3 deletions train.lua
@@ -38,6 +38,8 @@ cmd:option('-data_dir','data/tinyshakespeare','data directory. Should contain th
cmd:option('-rnn_size', 128, 'size of LSTM internal state')
cmd:option('-num_layers', 2, 'number of layers in the LSTM')
cmd:option('-model', 'lstm', 'lstm,gru or rnn')
-- Training visualization
cmd:option('-visualize', 0, 'port for the websocket server that visualizes model training in the browser. 0 (the default) disables it.')
-- optimization
cmd:option('-learning_rate',2e-3,'learning rate')
cmd:option('-learning_rate_decay',0.97,'learning rate decay')
@@ -67,6 +69,8 @@ cmd:text()
-- parse input params
opt = cmd:parse(arg)
torch.manualSeed(opt.seed)


-- train / val / test split for data, in fractions
local test_frac = math.max(0, 1 - (opt.train_frac + opt.val_frac))
local split_sizes = {opt.train_frac, opt.val_frac, test_frac}
@@ -107,6 +111,22 @@ if opt.gpuid >= 0 and opt.opencl == 1 then
end
end

-- optional dependencies for the web visualization; disable it gracefully if any are missing
if opt.visualize > 0 then
    local ok, ok2, ok3
    ok, copas = pcall(require, 'copas')
    ok2, websocket = pcall(require, 'websocket')
    ok3, json = pcall(require, 'json')
    if not ok then print('package copas not found!') end
    if not ok2 then print('package lua-websockets not found!') end
    if not ok3 then print('package luajson not found!') end
    if ok and ok2 and ok3 then
        print('visualization has been enabled on 127.0.0.1:' .. opt.visualize)
    else
        print('Could not start visualization. Please ensure you have properly installed copas, lua-websockets, and luajson')
        opt.visualize = 0
    end
end

-- create the data loader class
local loader = CharSplitLMMinibatchLoader.create(opt.data_dir, opt.batch_size, opt.seq_length, split_sizes)
local vocab_size = loader.vocab_size -- the number of distinct characters
@@ -195,6 +215,7 @@ if opt.model == 'lstm' then
end

print('number of parameters in the model: ' .. params:nElement())

-- make a bunch of clones after flattening, as that reallocates memory
clones = {}
for name,proto in pairs(protos) do
@@ -307,8 +328,66 @@ local optim_state = {learningRate = opt.learning_rate, alpha = opt.decay_rate}
local iterations = opt.max_epochs * loader.ntrain
local iterations_per_epoch = loader.ntrain
local loss0 = nil

if opt.visualize > 0 then
    ws_iteration = 0
    ws_batch_time = 0
    epoch = 0
    websocket.server.copas.listen
    {
        port = opt.visualize,
        protocols = {
            -- this callback is called whenever a new client connects.
            -- ws is a new websocket instance
            monitor = function(ws)
                -- greet the new client with the loss history recorded so far
                json_data = {}
                json_data["type"] = "initial"
                json_data["iterations"] = iterations
                json_data["iterations_per_epoch"] = iterations_per_epoch
                json_data["print_every"] = opt.print_every
                json_data["train_loss"] = {}
                for i = 1, ws_iteration do
                    if (i - 1) % opt.print_every == 0 then
                        table.insert(json_data["train_loss"], train_losses[i])
                    end
                end

                ws:send(json.encode(json_data))
                while true do
                    local msg = ws:receive()
                    if msg == nil then
                        ws:close()
                        break
                    else
                        local js = json.decode(msg)
                        if js.type == "update" then
                            -- reply with only the losses logged since the client's last_iteration
                            json_data = {}
                            json_data["type"] = "update"
                            json_data["iteration"] = ws_iteration
                            json_data["epoch"] = tonumber(string.format("%.3f", epoch))
                            json_data["batch_time"] = ws_batch_time
                            json_data["train_loss"] = {}
                            for i = js.last_iteration + 1, ws_iteration do
                                if i % opt.print_every == 0 then
                                    table.insert(json_data["train_loss"], train_losses[i])
                                end
                            end

                            ws:send(json.encode(json_data))
                        else
                            ws:close()
                            return
                        end
                    end
                end
            end
        }
    }
end


for i = 1, iterations do
    epoch = i / loader.ntrain -- intentionally not local, so the visualization callback can read it

    local timer = torch.Timer()
    local _, loss = optim.rmsprop(feval, params, optim_state)
@@ -321,7 +400,11 @@ for i = 1, iterations do
        cutorch.synchronize()
    end
    local time = timer:time().real

    if opt.visualize > 0 then
        -- set visualization specific vars
        ws_iteration = i
        ws_batch_time = time
    end
    local train_loss = loss[1] -- the loss is inside a list, pop it
    train_losses[i] = train_loss

@@ -357,9 +440,13 @@ for i = 1, iterations do
    if i % opt.print_every == 0 then
        print(string.format("%d/%d (epoch %.3f), train_loss = %6.8f, grad/param norm = %6.4e, time/batch = %.4fs", i, iterations, epoch, train_loss, grad_params:norm() / params:norm(), time))
    end

    if i % 10 == 0 then collectgarbage() end

    if opt.visualize > 0 then
        -- run one copas scheduler step so pending websocket clients get serviced
        copas.step(1)
    end

    -- handle early stopping if things are going really bad
    if loss[1] ~= loss[1] then
        print('loss is NaN. This usually indicates a bug. Please check the issues page for existing issues, or create a new issue, if none exist. Ideally, please state: your operating system, 32-bit/64-bit, your blas version, cpu/cuda/cl?')
@@ -373,3 +460,5 @@
end
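
As a rough illustration (not part of this pull request), the sketch below shows a minimal Lua client for the websocket protocol implemented by the `monitor` handler above. It assumes the synchronous client API of `lua-websockets` and `luajson`, and a training run started with `-visualize 8080`; the field names mirror the `json_data` tables built in `train.lua`.

```lua
-- Minimal sketch of a monitoring client for the training websocket server.
local websocket = require 'websocket'
local json = require 'json'

local client = websocket.client.sync({timeout = 5})
assert(client:connect('ws://127.0.0.1:8080', 'monitor'))

-- the server greets every new client with an "initial" message holding the loss history so far
local initial = json.decode(client:receive())
print('planned iterations:', initial.iterations)
print('loss samples so far:', #initial.train_loss)

-- ask for everything logged since the last iteration we have seen
client:send(json.encode({type = 'update', last_iteration = 0}))
local update = json.decode(client:receive())
print(string.format('iteration %d, epoch %.3f, time/batch %.4fs',
    update.iteration, update.epoch, update.batch_time))

client:close()
```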




5 changes: 5 additions & 0 deletions web_utils/bootstrap.min.css

Large diffs are not rendered by default.

7 changes: 7 additions & 0 deletions web_utils/bootstrap.min.js

Large diffs are not rendered by default.
