add async example in knet tutorial
Evizero committed Mar 19, 2018
1 parent 75df3a1 commit dd367d4
Showing 1 changed file with 67 additions and 0 deletions.
67 changes: 67 additions & 0 deletions examples/mnist_knet.jl
@@ -456,6 +456,73 @@ augmented_log = @time train_augmented(epochs=200);
#' performance (aside from simplifying the augmentation pipeline)
#' would be to increase the number of available threads.
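
#' As a side note: the number of threads is controlled by the
#' `JULIA_NUM_THREADS` environment variable, which has to be set
#' before Julia is started. The commented snippet below merely
#' queries the current setting.

# Threads.nthreads() # how many threads are currently available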

#' ## Improving Performance
info("Improving Performance") #jl-only

#' One of the most effective ways to make the most out of the
#' available resources is to augment the next couple of
#' mini-batches while the current mini-batch is being processed
#' on the GPU. We can do this using Julia's built-in parallel
#' computing capabilities.
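
#' As a rough sketch of the pattern (here on the current process
#' only, with integers standing in for augmented mini-batches):
#' one task produces items into a buffered channel, while the main
#' task consumes them as they become available. The real version
#' below additionally moves the producer onto a separate worker
#' process.

ch = Channel{Int}(4)   # buffer up to 4 items in advance
@async begin           # producer task ("augment the next batches")
    for i in 1:10
        put!(ch, i)
    end
    close(ch)
end
for item in ch         # consumer ("train on the current batch")
    print(item, " ")
end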

#' First we need a worker process that will be responsible for
#' augmenting our dataset each epoch. This worker also needs
#' access to a couple of our packages.

# addprocs(1)
# @everywhere using Augmentor, MLDataUtils

#' Next, we replace the inner `eachbatch` loop with a more
#' complicated version using a `RemoteChannel` to exchange and
#' queue the augmented data.

function async_train_augmented(; epochs = 500, batchsize = 100, lr = .03)
    w = weights()
    log = MVHistory()
    p = Progress(epochs, desc = "Async Augmented: ") #jl-only
    for epoch in 1:epochs
        @sync begin
            local_ch = Channel{Tuple}(4) # prepare up to 4 minibatches in advance
            remote_ch = RemoteChannel(()->local_ch)
            @spawn begin
                # This block is executed on the worker process
                batch_x_aug = zeros(Float32, size(train_x,1), size(train_x,2), 1, batchsize)
                for (batch_x_cpu, batch_y) in eachbatch((train_x, train_y), batchsize)
                    # we are still using multithreading
                    augmentbatch!(CPUThreads(), batch_x_aug, batch_x_cpu, pl)
                    put!(remote_ch, (batch_x_aug, batch_y))
                end
                close(remote_ch)
            end
            @async begin
                # This block is executed on the main process
                for (batch_x_aug, batch_y) in local_ch
                    batch_x = KnetArray{Float32}(batch_x_aug)
                    g = costgrad(w, batch_x, batch_y)
                    Knet.update!(w, g, lr = lr)
                end
            end
        end

        next!(p) #jl-only
        if (epoch % 5) == 0
            train = acc(w, train_x, train_y)
            test = acc(w, test_x, test_y)
            @trace log epoch train test
            msg = "epoch " * lpad(epoch,4) * ": train accuracy " * rpad(round(train,3),5,"0") * ", test accuracy " * rpad(round(test,3),5,"0")
            cancel(p, msg, :blue) #jl-only
            #md println(msg)
            #jp println(msg)
        end
    end
    finish!(p) #jl-only
    log
end
#md nothing # hide

#' Note that for this toy example the overhead of this approach
#' is greater than the benefit.
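
#' If you want to try it regardless, it can be invoked just like
#' `train_augmented` above (assuming the worker process from the
#' commented-out `addprocs` call has been added); the variable name
#' `async_log` is only chosen for illustration.

# async_log = @time async_train_augmented(epochs=200);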

#' ## Visualizing the Results
info("Visualizing the Results") #jl-only
