
Full experiment with ordering of dimensions in the forward pass for the input of type DenseArray{<:Real,3} #2

Open
AzamatB opened this issue Feb 4, 2020 · 3 comments


AzamatB commented Feb 4, 2020

Implementations of six different versions of the forward pass, one for each permutation of D (input dimension), T (time duration) and B (batch size):

"""
D × T × B
"""
function fdtb(m::LAS{M}, Xs::DenseArray{R,3}, maxT::Integer = size(Xs,2))::Vector{M} where {R <: Real, M <: DenseMatrix{R}}
   batch_size = size(Xs,3)
   # compute input encodings, which are also the values for the attention layer
   Hs = m.listen(Xs)
   # precompute keys ψ(H)
   ψhs = m.attention_ψ.(getindex.(Ref(Hs), :, axes(Hs,2), :))
   # compute initial decoder state for a batch
   m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context] .+ gpu(zeros(R, m.state.dim, batch_size)))
   ŷs = broadcast(1:maxT) do _
      # compute query ϕ(sᵢ)
      ϕsᵢᵀ = m.attention_ϕ(m.state.decoding)'
      # compute energies
      Eᵢs = diag.((ϕsᵢᵀ,) .* ψhs)
      # compute attention weights
      # αᵢs = softmax(hcat(Eᵢs...); dims=2)
      αᵢs = softmax(hcat(Eᵢs...)')
      # αᵢs = softmax(reduce(hcat, Eᵢs); dims=2)
      # αᵢs = softmax(reduce(hcat, Eᵢs)')
      # αᵢs = softmax(vcat(Eᵢs'...))
      # αᵢs = softmax(reduce(vcat, Eᵢs'))
      # compute attended context by normalizing values with respect to attention weights, i.e. contextᵢ = Σᵤαᵢᵤhᵤ
      # hcat(@inbounds([sum(αᵢs[b,u] * hs[u][:,b] for u ∈ eachindex(hs)) for b ∈ axes(αᵢs, 1)])...)
      m.state.context = dropdims(sum(reshape(αᵢs, 1, :, batch_size) .* Hs; dims=2); dims=2)
      # predict probability distribution over character alphabet
      m.state.prediction = m.infer([m.state.decoding; m.state.context])
      # compute decoder state
      m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context])
      return m.state.prediction
   end
   reset!(m)
   return ŷs
end
"""
D × B × T
"""
function fdbt(m::LAS{M}, Xs::DenseArray{R,3}, maxT::Integer = size(Xs,2))::Vector{M} where {R <: Real, M <: DenseMatrix{R}}
   batch_size = size(Xs,3)
   # compute input encodings, which are also the values for the attention layer
   Hs = permutedims(m.listen(Xs), [1,3,2])
   # precompute keys ψ(H)
   ψhs = m.attention_ψ.(getindex.(Ref(Hs), :, :, axes(Hs,3)))
   # compute initial decoder state for a batch
   m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context] .+ gpu(zeros(R, m.state.dim, batch_size)))
   ŷs = broadcast(1:maxT) do _
      # compute query ϕ(sᵢ)
      ϕsᵢᵀ = m.attention_ϕ(m.state.decoding)'
      # compute energies
      Eᵢs = diag.((ϕsᵢᵀ,) .* ψhs)
      # compute attention weights
      αᵢs = softmax(hcat(Eᵢs...); dims=2)
      # αᵢs = softmax(hcat(Eᵢs...)')
      # αᵢs = softmax(reduce(hcat, Eᵢs); dims=2)
      # αᵢs = softmax(reduce(hcat, Eᵢs)')
      # αᵢs = softmax(vcat(Eᵢs'...))
      # αᵢs = softmax(reduce(vcat, Eᵢs'))
      # compute attended context by normalizing values with respect to attention weights, i.e. contextᵢ = Σᵤαᵢᵤhᵤ
      # hcat(@inbounds([sum(αᵢs[b,u] * hs[u][:,b] for u ∈ eachindex(hs)) for b ∈ axes(αᵢs, 1)])...)
      m.state.context = dropdims(sum(reshape(αᵢs, 1, batch_size, :) .* Hs; dims=3); dims=3)
      # predict probability distribution over character alphabet
      m.state.prediction = m.infer([m.state.decoding; m.state.context])
      # compute decoder state
      m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context])
      return m.state.prediction
   end
   reset!(m)
   return ŷs
end
"""
T × D × B
"""
function ftdb(m::LAS{M}, Xs::DenseArray{R,3}, maxT::Integer = size(Xs,2))::Vector{M} where {R <: Real, M <: DenseMatrix{R}}
   batch_size = size(Xs,3)
   # compute input encodings, which are also the values for the attention layer
   Hs = m.listen(Xs)
   # precompute keys ψ(H)
   ψhs = m.attention_ψ.(getindex.(Ref(Hs), :, axes(Hs,2), :))
   Hs = permutedims(Hs, [2,1,3])
   # compute initial decoder state for a batch
   m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context] .+ gpu(zeros(R, m.state.dim, batch_size)))
   ŷs = broadcast(1:maxT) do _
      # compute query ϕ(sᵢ)
      ϕsᵢᵀ = m.attention_ϕ(m.state.decoding)'
      # compute energies
      Eᵢs = diag.((ϕsᵢᵀ,) .* ψhs)
      # compute attention weights
      # αᵢs = softmax(hcat(Eᵢs...); dims=2)
      αᵢs = softmax(hcat(Eᵢs...)')
      # αᵢs = softmax(reduce(hcat, Eᵢs); dims=2)
      # αᵢs = softmax(reduce(hcat, Eᵢs)')
      # αᵢs = softmax(vcat(Eᵢs'...))
      # αᵢs = softmax(reduce(vcat, Eᵢs'))
      # compute attended context by normalizing values with respect to attention weights, i.e. contextᵢ = Σᵤαᵢᵤhᵤ
      # hcat(@inbounds([sum(αᵢs[b,u] * hs[u][:,b] for u ∈ eachindex(hs)) for b ∈ axes(αᵢs, 1)])...)
      m.state.context = dropdims(sum(reshape(αᵢs, :,1, batch_size) .* Hs; dims=1); dims=1)
      # predict probability distribution over character alphabet
      m.state.prediction = m.infer([m.state.decoding; m.state.context])
      # compute decoder state
      m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context])
      return m.state.prediction
   end
   reset!(m)
   return ŷs
end
"""
B × T × D
"""
function fbtd(m::LAS{M}, Xs::DenseArray{R,3}, maxT::Integer = size(Xs,2))::Vector{M} where {R <: Real, M <: DenseMatrix{R}}
   batch_size = size(Xs,3)
   # compute input encodings, which are also the values for the attention layer
   Hs = m.listen(Xs)
   # precompute keys ψ(H)
   ψhs = m.attention_ψ.(getindex.(Ref(Hs), :, axes(Hs,2), :))
   Hs = permutedims(Hs, [3,2,1])
   # compute initial decoder state for a batch
   m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context] .+ gpu(zeros(R, m.state.dim, batch_size)))
   ŷs = broadcast(1:maxT) do _
      # compute query ϕ(sᵢ)
      ϕsᵢᵀ = m.attention_ϕ(m.state.decoding)'
      # compute energies
      Eᵢs = diag.((ϕsᵢᵀ,) .* ψhs)
      # compute attention weights
      αᵢs = softmax(hcat(Eᵢs...); dims=2)
      # αᵢs = softmax(hcat(Eᵢs...)')
      # αᵢs = softmax(reduce(hcat, Eᵢs); dims=2)
      # αᵢs = softmax(reduce(hcat, Eᵢs)')
      # αᵢs = softmax(vcat(Eᵢs'...))
      # αᵢs = softmax(reduce(vcat, Eᵢs'))
      # compute attended context by normalizing values with respect to attention weights, i.e. contextᵢ = Σᵤαᵢᵤhᵤ
      # hcat(@inbounds([sum(αᵢs[b,u] * hs[u][:,b] for u ∈ eachindex(hs)) for b ∈ axes(αᵢs, 1)])...)
      m.state.context = permutedims(dropdims(sum(αᵢs .* Hs; dims=2); dims=2))
      # predict probability distribution over character alphabet
      m.state.prediction = m.infer([m.state.decoding; m.state.context])
      # compute decoder state
      m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context])
      return m.state.prediction
   end
   reset!(m)
   return ŷs
end
"""
T × B × D
"""
function ftbd(m::LAS{M}, Xs::DenseArray{R,3}, maxT::Integer = size(Xs,2))::Vector{M} where {R <: Real, M <: DenseMatrix{R}}
   batch_size = size(Xs,3)
   # compute input encodings, which are also the values for the attention layer
   Hs = m.listen(Xs)
   # precompute keys ψ(H)
   ψhs = m.attention_ψ.(getindex.(Ref(Hs), :, axes(Hs,2), :))
   Hs = permutedims(Hs, [2,3,1])
   # compute initial decoder state for a batch
   m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context] .+ gpu(zeros(R, m.state.dim, batch_size)))
   ŷs = broadcast(1:maxT) do _
      # compute query ϕ(sᵢ)
      ϕsᵢᵀ = m.attention_ϕ(m.state.decoding)'
      # compute energies
      Eᵢs = diag.((ϕsᵢᵀ,) .* ψhs)
      # compute attention weights
      # αᵢs = softmax(hcat(Eᵢs...); dims=2)
      αᵢs = softmax(hcat(Eᵢs...)')
      # αᵢs = softmax(reduce(hcat, Eᵢs); dims=2)
      # αᵢs = softmax(reduce(hcat, Eᵢs)')
      # αᵢs = softmax(vcat(Eᵢs'...))
      # αᵢs = softmax(reduce(vcat, Eᵢs'))
      # compute attended context by normalizing values with respect to attention weights, i.e. contextᵢ = Σᵤαᵢᵤhᵤ
      # hcat(@inbounds([sum(αᵢs[b,u] * hs[u][:,b] for u ∈ eachindex(hs)) for b ∈ axes(αᵢs, 1)])...)
      m.state.context = permutedims(dropdims(sum(αᵢs .* Hs; dims=1); dims=1))
      # predict probability distribution over character alphabet
      m.state.prediction = m.infer([m.state.decoding; m.state.context])
      # compute decoder state
      m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context])
      return m.state.prediction
   end
   reset!(m)
   return ŷs
end
"""
B × D × T
"""
function fbdt(m::LAS{M}, Xs::DenseArray{R,3}, maxT::Integer = size(Xs,2))::Vector{M} where {R <: Real, M <: DenseMatrix{R}}
   batch_size = size(Xs,3)
   # compute input encodings, which are also the values for the attention layer
   Hs = m.listen(Xs)
   # precompute keys ψ(H)
   ψhs = m.attention_ψ.(getindex.(Ref(Hs), :, axes(Hs,2), :))
   Hs = permutedims(Hs, [3,1,2])
   # compute initial decoder state for a batch
   m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context] .+ gpu(zeros(R, m.state.dim, batch_size)))
   ŷs = broadcast(1:maxT) do _
      # compute query ϕ(sᵢ)
      ϕsᵢᵀ = m.attention_ϕ(m.state.decoding)'
      # compute energies
      Eᵢs = diag.((ϕsᵢᵀ,) .* ψhs)
      # compute attention weights
      αᵢs = softmax(hcat(Eᵢs...); dims=2)
      # αᵢs = softmax(hcat(Eᵢs...)')
      # αᵢs = softmax(reduce(hcat, Eᵢs); dims=2)
      # αᵢs = softmax(reduce(hcat, Eᵢs)')
      # αᵢs = softmax(vcat(Eᵢs'...))
      # αᵢs = softmax(reduce(vcat, Eᵢs'))
      # compute attended context by normalizing values with respect to attention weights, i.e. contextᵢ = Σᵤαᵢᵤhᵤ
      # hcat(@inbounds([sum(αᵢs[b,u] * hs[u][:,b] for u ∈ eachindex(hs)) for b ∈ axes(αᵢs, 1)])...)
      m.state.context = permutedims(dropdims(sum(reshape(αᵢs, batch_size, 1, :) .* Hs; dims=3); dims=3))
      # predict probability distribution over character alphabet
      m.state.prediction = m.infer([m.state.decoding; m.state.context])
      # compute decoder state
      m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context])
      return m.state.prediction
   end
   reset!(m)
   return ŷs
end
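For reference, the broadcast context reduction used in the D × T × B version is equivalent to the commented-out per-batch reference sum contextᵢ = Σᵤαᵢᵤhᵤ. A minimal NumPy sketch of that equivalence (array sizes here are made up, and αᵢs is taken as T × B weights, matching the shape after the transpose in the code above):

```python
import numpy as np

rng = np.random.default_rng(0)
D, T, B = 4, 5, 3                    # feature, time, batch (hypothetical sizes)
Hs = rng.normal(size=(D, T, B))      # encoder outputs in D x T x B ordering
alphas = rng.random(size=(T, B))
alphas /= alphas.sum(axis=0)         # softmax-like: weights over time, per batch item

# broadcast reduction, mirroring dropdims(sum(reshape(αᵢs, 1, :, B) .* Hs; dims=2); dims=2)
context = (alphas[None, :, :] * Hs).sum(axis=1)   # D x B

# reference: context[:, b] = Σᵤ αᵢs[u, b] * Hs[:, u, b]
ref = np.stack([sum(alphas[u, b] * Hs[:, u, b] for u in range(T))
                for b in range(B)], axis=1)

assert np.allclose(context, ref)
```

The broadcast form does the same arithmetic as the explicit double loop but as one fused elementwise product plus reduction, which is what makes it cheap on both CPU and GPU.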

function gfdtb(m, Xs, θ)
   gradient(θ) do
      sum(sum(fdtb(m, Xs)))
   end
end
function gfdbt(m, Xs, θ)
   gradient(θ) do
      sum(sum(fdbt(m, Xs)))
   end
end
function gftdb(m, Xs, θ)
   gradient(θ) do
      sum(sum(ftdb(m, Xs)))
   end
end
function gfbtd(m, Xs, θ)
   gradient(θ) do
      sum(sum(fbtd(m, Xs)))
   end
end
function gftbd(m, Xs, θ)
   gradient(θ) do
      sum(sum(ftbd(m, Xs)))
   end
end
function gfbdt(m, Xs, θ)
   gradient(θ) do
      sum(sum(fbdt(m, Xs)))
   end
end

A smallish neural net with the following dimensions was used:

encoder_dims = (
   blstm       = (in = 39, out = 64),
   pblstms_out = (64, 64, 64)
)
attention_dim = 128
decoder_out_dims = (128, 64)
m = LAS(encoder_dims, attention_dim, decoder_out_dims, out_dim)
θ = Flux.params(m)
using BenchmarkTools

Results for xs = last(Xs_train); Xs = vecofmats2tensor(xs):

julia> reset!(m);

julia> @benchmark fdtb($m, $Xs)
BenchmarkTools.Trial: 
  memory estimate:  2.94 GiB
  allocs estimate:  306948
  --------------
  minimum time:     3.294 s (11.02% GC)
  median time:      3.295 s (11.20% GC)
  mean time:        3.295 s (11.20% GC)
  maximum time:     3.296 s (11.39% GC)
  --------------
  samples:          2
  evals/sample:     1

julia> reset!(m);

julia> @benchmark fdbt($m, $Xs)
BenchmarkTools.Trial: 
  memory estimate:  2.94 GiB
  allocs estimate:  318347
  --------------
  minimum time:     3.485 s (10.91% GC)
  median time:      3.514 s (10.76% GC)
  mean time:        3.514 s (10.76% GC)
  maximum time:     3.543 s (10.61% GC)
  --------------
  samples:          2
  evals/sample:     1

julia> reset!(m);

julia> @benchmark ftdb($m, $Xs)
BenchmarkTools.Trial: 
  memory estimate:  3.54 GiB
  allocs estimate:  367033
  --------------
  minimum time:     4.517 s (10.70% GC)
  median time:      4.523 s (10.57% GC)
  mean time:        4.523 s (10.57% GC)
  maximum time:     4.529 s (10.44% GC)
  --------------
  samples:          2
  evals/sample:     1

julia> reset!(m);

julia> @benchmark fbtd($m, $Xs)
BenchmarkTools.Trial: 
  memory estimate:  3.55 GiB
  allocs estimate:  386793
  --------------
  minimum time:     4.498 s (10.76% GC)
  median time:      4.501 s (10.51% GC)
  mean time:        4.501 s (10.51% GC)
  maximum time:     4.504 s (10.26% GC)
  --------------
  samples:          2
  evals/sample:     1

julia> reset!(m);

julia> @benchmark ftbd($m, $Xs)
BenchmarkTools.Trial: 
  memory estimate:  3.55 GiB
  allocs estimate:  366273
  --------------
  minimum time:     4.461 s (10.82% GC)
  median time:      4.469 s (10.95% GC)
  mean time:        4.469 s (10.95% GC)
  maximum time:     4.477 s (11.07% GC)
  --------------
  samples:          2
  evals/sample:     1

julia> reset!(m);

julia> @benchmark fbdt($m, $Xs)
BenchmarkTools.Trial: 
  memory estimate:  3.55 GiB
  allocs estimate:  387553
  --------------
  minimum time:     5.083 s (9.19% GC)
  median time:      5.083 s (9.19% GC)
  mean time:        5.083 s (9.19% GC)
  maximum time:     5.083 s (9.19% GC)
  --------------
  samples:          1
  evals/sample:     1

julia> reset!(m);

julia> @benchmark gfdtb($m, $Xs, $θ)
BenchmarkTools.Trial: 
  memory estimate:  17.50 GiB
  allocs estimate:  2662077
  --------------
  minimum time:     30.478 s (64.75% GC)
  median time:      30.478 s (64.75% GC)
  mean time:        30.478 s (64.75% GC)
  maximum time:     30.478 s (64.75% GC)
  --------------
  samples:          1
  evals/sample:     1

julia> reset!(m);

julia> @benchmark gfdbt($m, $Xs, $θ)
BenchmarkTools.Trial: 
  memory estimate:  17.51 GiB
  allocs estimate:  2689455
  --------------
  minimum time:     30.562 s (64.84% GC)
  median time:      30.562 s (64.84% GC)
  mean time:        30.562 s (64.84% GC)
  maximum time:     30.562 s (64.84% GC)
  --------------
  samples:          1
  evals/sample:     1

julia> reset!(m);

julia> @benchmark gftdb($m, $Xs, $θ)
BenchmarkTools.Trial: 
  memory estimate:  18.29 GiB
  allocs estimate:  2968508
  --------------
  minimum time:     28.648 s (57.85% GC)
  median time:      28.648 s (57.85% GC)
  mean time:        28.648 s (57.85% GC)
  maximum time:     28.648 s (57.85% GC)
  --------------
  samples:          1
  evals/sample:     1

julia> reset!(m);

julia> @benchmark gfbtd($m, $Xs, $θ)
BenchmarkTools.Trial: 
  memory estimate:  18.32 GiB
  allocs estimate:  3001183
  --------------
  minimum time:     28.857 s (57.49% GC)
  median time:      28.857 s (57.49% GC)
  mean time:        28.857 s (57.49% GC)
  maximum time:     28.857 s (57.49% GC)
  --------------
  samples:          1
  evals/sample:     1

julia> reset!(m);

julia> @benchmark gftbd($m, $Xs, $θ)
BenchmarkTools.Trial: 
  memory estimate:  18.32 GiB
  allocs estimate:  2964703
  --------------
  minimum time:     28.671 s (57.99% GC)
  median time:      28.671 s (57.99% GC)
  mean time:        28.671 s (57.99% GC)
  maximum time:     28.671 s (57.99% GC)
  --------------
  samples:          1
  evals/sample:     1

julia> reset!(m);

julia> @benchmark gfbdt($m, $Xs, $θ)
BenchmarkTools.Trial: 
  memory estimate:  18.32 GiB
  allocs estimate:  3008029
  --------------
  minimum time:     28.963 s (57.72% GC)
  median time:      28.963 s (57.72% GC)
  mean time:        28.963 s (57.72% GC)
  maximum time:     28.963 s (57.72% GC)
  --------------
  samples:          1
  evals/sample:     1

Results for xs = first(Xs_train); Xs = vecofmats2tensor(xs):

julia> reset!(m);

julia> @benchmark fdtb($m, $Xs)
BenchmarkTools.Trial: 
  memory estimate:  1.72 GiB
  allocs estimate:  54890
  --------------
  minimum time:     2.456 s (7.91% GC)
  median time:      2.509 s (9.16% GC)
  mean time:        2.504 s (8.79% GC)
  maximum time:     2.548 s (9.26% GC)
  --------------
  samples:          3
  evals/sample:     1

julia> reset!(m);

julia> @benchmark fdbt($m, $Xs)
BenchmarkTools.Trial: 
  memory estimate:  1.72 GiB
  allocs estimate:  57409
  --------------
  minimum time:     2.532 s (8.33% GC)
  median time:      2.604 s (8.87% GC)
  mean time:        2.604 s (8.87% GC)
  maximum time:     2.676 s (9.38% GC)
  --------------
  samples:          2
  evals/sample:     1

julia> reset!(m);

julia> @benchmark ftdb($m, $Xs)
BenchmarkTools.Trial: 
  memory estimate:  2.28 GiB
  allocs estimate:  74143
  --------------
  minimum time:     3.719 s (7.99% GC)
  median time:      3.754 s (8.27% GC)
  mean time:        3.754 s (8.27% GC)
  maximum time:     3.789 s (8.54% GC)
  --------------
  samples:          2
  evals/sample:     1

julia> reset!(m);

julia> @benchmark fbtd($m, $Xs)
BenchmarkTools.Trial: 
  memory estimate:  2.29 GiB
  allocs estimate:  78511
  --------------
  minimum time:     3.604 s (8.44% GC)
  median time:      3.623 s (8.55% GC)
  mean time:        3.623 s (8.55% GC)
  maximum time:     3.643 s (8.66% GC)
  --------------
  samples:          2
  evals/sample:     1

julia> reset!(m);

julia> @benchmark ftbd($m, $Xs)
BenchmarkTools.Trial: 
  memory estimate:  2.29 GiB
  allocs estimate:  73975
  --------------
  minimum time:     3.684 s (8.50% GC)
  median time:      3.710 s (8.67% GC)
  mean time:        3.710 s (8.67% GC)
  maximum time:     3.736 s (8.84% GC)
  --------------
  samples:          2
  evals/sample:     1

julia> reset!(m);

julia> @benchmark fbdt($m, $Xs)
BenchmarkTools.Trial: 
  memory estimate:  2.29 GiB
  allocs estimate:  78679
  --------------
  minimum time:     3.624 s (8.84% GC)
  median time:      3.636 s (8.83% GC)
  mean time:        3.636 s (8.83% GC)
  maximum time:     3.647 s (8.81% GC)
  --------------
  samples:          2
  evals/sample:     1

julia> reset!(m);

julia> @benchmark gfdtb($m, $Xs, $θ)
BenchmarkTools.Trial: 
  memory estimate:  5.75 GiB
  allocs estimate:  362721
  --------------
  minimum time:     8.457 s (34.28% GC)
  median time:      8.457 s (34.28% GC)
  mean time:        8.457 s (34.28% GC)
  maximum time:     8.457 s (34.28% GC)
  --------------
  samples:          1
  evals/sample:     1

julia> reset!(m);

julia> @benchmark gfdbt($m, $Xs, $θ)
BenchmarkTools.Trial: 
  memory estimate:  5.75 GiB
  allocs estimate:  368787
  --------------
  minimum time:     8.571 s (33.81% GC)
  median time:      8.571 s (33.81% GC)
  mean time:        8.571 s (33.81% GC)
  maximum time:     8.571 s (33.81% GC)
  --------------
  samples:          1
  evals/sample:     1

julia> reset!(m);

julia> @benchmark gftdb($m, $Xs, $θ)
BenchmarkTools.Trial: 
  memory estimate:  6.48 GiB
  allocs estimate:  440883
  --------------
  minimum time:     10.769 s (34.73% GC)
  median time:      10.769 s (34.73% GC)
  mean time:        10.769 s (34.73% GC)
  maximum time:     10.769 s (34.73% GC)
  --------------
  samples:          1
  evals/sample:     1

julia> reset!(m);

julia> @benchmark gfbtd($m, $Xs, $θ)
BenchmarkTools.Trial: 
  memory estimate:  6.50 GiB
  allocs estimate:  448102
  --------------
  minimum time:     10.603 s (35.30% GC)
  median time:      10.603 s (35.30% GC)
  mean time:        10.603 s (35.30% GC)
  maximum time:     10.603 s (35.30% GC)
  --------------
  samples:          1
  evals/sample:     1

julia> reset!(m);

julia> @benchmark gftbd($m, $Xs, $θ)
BenchmarkTools.Trial: 
  memory estimate:  6.50 GiB
  allocs estimate:  440038
  --------------
  minimum time:     10.768 s (34.75% GC)
  median time:      10.768 s (34.75% GC)
  mean time:        10.768 s (34.75% GC)
  maximum time:     10.768 s (34.75% GC)
  --------------
  samples:          1
  evals/sample:     1

julia> reset!(m);

julia> @benchmark gfbdt($m, $Xs, $θ)
BenchmarkTools.Trial: 
  memory estimate:  6.50 GiB
  allocs estimate:  449619
  --------------
  minimum time:     10.641 s (35.38% GC)
  median time:      10.641 s (35.38% GC)
  mean time:        10.641 s (35.38% GC)
  maximum time:     10.641 s (35.38% GC)
  --------------
  samples:          1
  evals/sample:     1

Conclusion: the D × T × B and D × B × T orderings appear to be the most efficient, although there is not much difference between the versions; as T grows, the computation becomes dominated by garbage collection and the speed differences almost vanish.


AzamatB commented Feb 4, 2020

Results for computing the energy vector by taking the diagonal of the product of the query and key matrices (f0) vs. their columnwise dot products (f1):

function f0(m::LAS{M}, Xs::DenseArray{R,3}, maxT::Integer = size(Xs,2))::Vector{M} where {R <: Real, M <: DenseMatrix{R}}
   batch_size = size(Xs,3)
   # compute input encodings, which are also the values for the attention layer
   Hs = m.listen(Xs)
   # precompute keys ψ(H)
   ψhs = m.attention_ψ.(getindex.(Ref(Hs), :, axes(Hs,2), :))
   # compute initial decoder state for a batch
   m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context] .+ gpu(zeros(R, m.state.dim, batch_size)))
   ŷs = broadcast(1:maxT) do _
      # compute query ϕ(sᵢ)
      ϕsᵢᵀ = m.attention_ϕ(m.state.decoding)'
      # compute energies
      Eᵢs = diag.((ϕsᵢᵀ,) .* ψhs)
      # compute attention weights
      # αᵢs = softmax(hcat(Eᵢs...); dims=2)
      αᵢs = softmax(hcat(Eᵢs...)')
      # αᵢs = softmax(reduce(hcat, Eᵢs); dims=2)
      # αᵢs = softmax(reduce(hcat, Eᵢs)')
      # αᵢs = softmax(vcat(Eᵢs'...))
      # αᵢs = softmax(reduce(vcat, Eᵢs'))
      # compute attended context by normalizing values with respect to attention weights, i.e. contextᵢ = Σᵤαᵢᵤhᵤ
      # hcat(@inbounds([sum(αᵢs[b,u] * hs[u][:,b] for u ∈ eachindex(hs)) for b ∈ axes(αᵢs, 1)])...)
      m.state.context = dropdims(sum(reshape(αᵢs, 1, :, batch_size) .* Hs; dims=2); dims=2)
      # predict probability distribution over character alphabet
      m.state.prediction = m.infer([m.state.decoding; m.state.context])
      # compute decoder state
      m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context])
      return m.state.prediction
   end
   reset!(m)
   return ŷs
end

function f1(m::LAS{M}, Xs::DenseArray{R,3}, maxT::Integer = size(Xs,2))::Vector{M} where {R <: Real, M <: DenseMatrix{R}}
   batch_size = size(Xs,3)
   batch_axis = axes(Xs,3)
   # compute input encodings, which are also the values for the attention layer
   Hs = m.listen(Xs)
   # precompute keys ψ(H)
   ψhs = m.attention_ψ.(getindex.(Ref(Hs), :, axes(Hs,2), :))
   # compute initial decoder state for a batch
   m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context] .+ gpu(zeros(R, m.state.dim, batch_size)))
   ŷs = broadcast(1:maxT) do _
      # compute query ϕ(sᵢ)
      ϕsᵢ = view.(Ref(m.attention_ϕ(m.state.decoding)), :, batch_axis)
      # compute energies
      Eᵢs = (ψh -> ϕsᵢ .⋅ view.(Ref(ψh), :, batch_axis)).(ψhs)
      # compute attention weights
      # αᵢs = softmax(hcat(Eᵢs...); dims=2)
      αᵢs = softmax(hcat(Eᵢs...)')
      # αᵢs = softmax(reduce(hcat, Eᵢs); dims=2)
      # αᵢs = softmax(reduce(hcat, Eᵢs)')
      # αᵢs = softmax(vcat(Eᵢs'...))
      # αᵢs = softmax(reduce(vcat, Eᵢs'))
      # compute attended context by normalizing values with respect to attention weights, i.e. contextᵢ = Σᵤαᵢᵤhᵤ
      # hcat(@inbounds([sum(αᵢs[b,u] * hs[u][:,b] for u ∈ eachindex(hs)) for b ∈ axes(αᵢs, 1)])...)
      m.state.context = dropdims(sum(reshape(αᵢs, 1, :, batch_size) .* Hs; dims=2); dims=2)
      # predict probability distribution over character alphabet
      m.state.prediction = m.infer([m.state.decoding; m.state.context])
      # compute decoder state
      m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context])
      return m.state.prediction
   end
   reset!(m)
   return ŷs
end

function gf0(m, Xs, θ)
   gradient(θ) do
      sum(sum(f0(m, Xs)))
   end
end
function gf1(m, Xs, θ)
   gradient(θ) do
      sum(sum(f1(m, Xs)))
   end
end

Results for xs = first(Xs_train); Xs = vecofmats2tensor(xs):

julia> reset!(m)

julia> @benchmark f0($m, $Xs)
BenchmarkTools.Trial: 
  memory estimate:  1.72 GiB
  allocs estimate:  54890
  --------------
  minimum time:     2.461 s (7.70% GC)
  median time:      2.514 s (8.95% GC)
  mean time:        2.511 s (8.64% GC)
  maximum time:     2.556 s (9.25% GC)
  --------------
  samples:          3
  evals/sample:     1

julia> reset!(m)

julia> @benchmark f1($m, $Xs)
BenchmarkTools.Trial: 
  memory estimate:  1.40 GiB
  allocs estimate:  643228
  --------------
  minimum time:     2.222 s (7.58% GC)
  median time:      2.251 s (8.09% GC)
  mean time:        2.244 s (7.92% GC)
  maximum time:     2.261 s (8.09% GC)
  --------------
  samples:          3
  evals/sample:     1

julia> reset!(m)

julia> @benchmark gf0($m, $Xs, $θ)
BenchmarkTools.Trial: 
  memory estimate:  5.75 GiB
  allocs estimate:  362721
  --------------
  minimum time:     8.926 s (34.64% GC)
  median time:      8.926 s (34.64% GC)
  mean time:        8.926 s (34.64% GC)
  maximum time:     8.926 s (34.64% GC)
  --------------
  samples:          1
  evals/sample:     1

julia> reset!(m)

julia> @benchmark gf1($m, $Xs, $θ)
BenchmarkTools.Trial: 
  memory estimate:  97.39 GiB
  allocs estimate:  18548950
  --------------
  minimum time:     95.674 s (28.28% GC)
  median time:      95.674 s (28.28% GC)
  mean time:        95.674 s (28.28% GC)
  maximum time:     95.674 s (28.28% GC)
  --------------
  samples:          1
  evals/sample:     1

Switching from views to getindex in the energy computation step only increased the total time further, to 110.982 seconds.
Conclusion: columnwise dot products (f1), whether via view or getindex, are far (more than 10×) slower (reported above) than taking the diagonal of the product of the query and key matrices (f0) when it comes to computing gradients.
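The two energy computations are mathematically identical, which is what makes this a pure performance comparison: diag(ϕᵀψ) picks out exactly the columnwise dot products ϕ[:,b] ⋅ ψ[:,b]. A NumPy sketch of that identity (shapes are hypothetical):

```python
import numpy as np

rng = np.random.default_rng(1)
A, B = 6, 4                      # attention dim, batch size (hypothetical)
phi = rng.normal(size=(A, B))    # query ϕ(sᵢ), one column per batch item
psi = rng.normal(size=(A, B))    # key  ψ(h),  one column per batch item

# f0-style: energies as the diagonal of the full query-key product
E_diag = np.diag(phi.T @ psi)

# f1-style: the same energies as columnwise dot products
E_dot = np.einsum('ab,ab->b', phi, psi)   # E[b] = ϕ[:,b] ⋅ ψ[:,b]

assert np.allclose(E_diag, E_dot)
```

The diag form does O(B²·A) work versus O(B·A) for the dot products, yet it maps to a single matrix multiply, which evidently has a far cheaper reverse-mode gradient than B separate dot products in this AD framework.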


AzamatB commented Feb 4, 2020

Changing

ŷs = map(1:maxT) do _

to

ŷs = broadcast(1:maxT) do _

did not yield any performance gain.


AzamatB commented Feb 5, 2020

Changing αᵢs = softmax(hcat(Eᵢs...)') to

  • αᵢs = softmax(stack(Eᵢs)') slightly reduces allocations, so it was given preference
  • αᵢs = softmax(reduce(hcat, Eᵢs)') throws ERROR: Can't differentiate gc_preserve_end expression
  • αᵢs = softmax(reduce(vcat, Eᵢs')) throws ERROR: Can't differentiate loopinfo expression

also with D × B × T ordering of input dimensions:

  • αᵢs = softmax(stack(Eᵢs); dims=2) slightly increases the runtime
  • αᵢs = softmax(reduce(hcat, Eᵢs); dims=2) still throws ERROR: Can't differentiate gc_preserve_end expression
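All of these variants compute the same attention weights; they differ only in which axis carries time, i.e. in whether the softmax normalizes columns of the transpose or rows of the original. A small NumPy sketch of the transpose identity involved (softmax defined inline, sizes made up):

```python
import numpy as np

def softmax(x, axis):
    # numerically stable softmax along the given axis
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(2)
E = rng.normal(size=(3, 5))   # B x T energies, hypothetical sizes

# softmax(hcat(Eᵢs...)'): transpose to T x B, normalize each column (over time)
a1 = softmax(E.T, axis=0)
# softmax(...; dims=2)-style: normalize each row of B x T (over time), no transpose
a2 = softmax(E, axis=1)

assert np.allclose(a1, a2.T)
```

So the choice between the variants is purely about allocation patterns and what the AD can differentiate, not about the numerical result.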
