-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Full experiment with ordering of dimensions in the forward pass for the input of type DenseArray{<:Real,3} #2
Comments
Results for computing energy vector by taking the diagonal of the result of multiplication of query and keys matrices ( function f0(m::LAS{M}, Xs::DenseArray{R,3}, maxT::Integer = size(Xs,2))::Vector{M} where {R <: Real, M <: DenseMatrix{R}}
# f0: forward pass over a 3-D input; returns one prediction (m.state.prediction) per
# decoding step, for maxT steps (default: size(Xs,2)).
# This variant obtains the attention energies as the diagonal of the product of the
# adjoint query matrix with each key matrix.
# NOTE(review): dimension 3 of Xs is used as the batch dimension here — confirm with callers.
# The function mutates m.state on every step and calls reset!(m) before returning.
batch_size = size(Xs,3)
# compute input encoding, which are also values for the attention layer
Hs = m.listen(Xs)
# precompute keys ψ(H): one key matrix per index along the 2nd dimension of Hs
# (Ref(Hs) keeps Hs from being broadcast over; only axes(Hs,2) is iterated)
ψhs = m.attention_ψ.(getindex.(Ref(Hs), :, axes(Hs,2), :))
# compute initial decoder state for a batch
# NOTE(review): the `.+ zeros(R, m.state.dim, batch_size)` presumably broadcasts the stored
# state across batch_size columns — confirm the shape of the m.state fields.
m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context] .+ gpu(zeros(R, m.state.dim, batch_size)))
# the step index is unused: each step depends only on the state left by the previous one
ŷs = broadcast(1:maxT) do _
# compute query ϕ(sᵢ) (adjoint taken so it can left-multiply each key matrix)
ϕsᵢᵀ = m.attention_ϕ(m.state.decoding)'
# compute energies: full matrix product per key matrix, of which only the diagonal is kept
# ((ϕsᵢᵀ,) wraps the query in a tuple so that `.*` broadcasts over ψhs only)
Eᵢs = diag.((ϕsᵢᵀ,) .* ψhs)
# compute attention weights
# αᵢs = softmax(hcat(Eᵢs...); dims=2)
αᵢs = softmax(hcat(Eᵢs...)')
# αᵢs = softmax(reduce(hcat, Eᵢs); dims=2)
# αᵢs = softmax(reduce(hcat, Eᵢs)')
# αᵢs = softmax(vcat(Eᵢs'...))
# αᵢs = softmax(reduce(vcat, Eᵢs'))
# compute attended context by normalizing values with respect to attention weights, i.e. contextᵢ = Σᵤαᵢᵤhᵤ
# hcat(@inbounds([sum(αᵢs[b,u] * hs[u][:,b] for u ∈ eachindex(hs)) for b ∈ axes(αᵢs, 1)])...)
# weights are reshaped to 1 × (steps) × batch_size so they broadcast against Hs; the sum
# over dimension 2 is then dropped to leave a matrix
m.state.context = dropdims(sum(reshape(αᵢs, 1, :, batch_size) .* Hs; dims=2); dims=2)
# predict probability distribution over character alphabet
m.state.prediction = m.infer([m.state.decoding; m.state.context])
# compute decoder state
m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context])
return m.state.prediction
end
# reset! is defined elsewhere; presumably it restores m.state for the next call — confirm
reset!(m)
return ŷs
end
"""
    f1(m::LAS, Xs::DenseArray{<:Real,3}, maxT::Integer = size(Xs, 2))

Run the forward pass of `m` on the 3-D input `Xs` for `maxT` decoding steps and return
the vector of per-step predictions (the successive values of `m.state.prediction`).

Attention energies are computed with one dot product per batch index between column
views of the query matrix and column views of each key matrix.

The third dimension of `Xs` is used as the batch dimension. `m.state` is mutated on
every step and `reset!(m)` is called before returning.
"""
function f1(m::LAS{M}, Xs::DenseArray{R,3}, maxT::Integer = size(Xs,2))::Vector{M} where {R <: Real, M <: DenseMatrix{R}}
    batch_size = size(Xs, 3)
    batch_axis = axes(Xs, 3)

    # Encoder output; it also serves as the values of the attention layer.
    Hs = m.listen(Xs)

    # Keys ψ(H): one matrix per index along the second dimension of Hs.
    slices = getindex.(Ref(Hs), :, axes(Hs, 2), :)
    key_mats = m.attention_ψ.(slices)

    # Initial decoder state for the batch. The zero matrix presumably broadcasts the
    # stored state across `batch_size` columns (TODO confirm shape of m.state fields).
    stacked_state = [m.state.decoding; m.state.prediction; m.state.context]
    zero_batch = gpu(zeros(R, m.state.dim, batch_size))
    m.state.decoding = m.spell(stacked_state .+ zero_batch)

    # One decoding step; the step index is irrelevant because everything needed is
    # carried in m.state.
    function decode_step(_)
        # Query ϕ(sᵢ), split into one column view per batch index.
        query = m.attention_ϕ(m.state.decoding)
        query_cols = view.(Ref(query), :, batch_axis)

        # Energies: for each key matrix, dot every query column with the key column
        # of the same batch index.
        energy(key_mat) = query_cols .⋅ view.(Ref(key_mat), :, batch_axis)
        energies = energy.(key_mats)

        # Attention weights.
        weights = softmax(hcat(energies...)')

        # Attended context, contextᵢ = Σᵤ αᵢᵤ hᵤ: weight the values and sum over
        # the second dimension.
        weighted_values = reshape(weights, 1, :, batch_size) .* Hs
        m.state.context = dropdims(sum(weighted_values; dims=2); dims=2)

        # Prediction over the character alphabet, then the next decoder state.
        m.state.prediction = m.infer([m.state.decoding; m.state.context])
        m.state.decoding = m.spell([m.state.decoding; m.state.prediction; m.state.context])
        return m.state.prediction
    end

    predictions = broadcast(decode_step, 1:maxT)
    reset!(m)
    return predictions
end
"""
    gf0(m, Xs, θ)

Return the gradient with respect to `θ` of the scalar obtained by summing all outputs
of `f0(m, Xs)` and then all entries of that sum.
"""
function gf0(m, Xs, θ)
    # Scalar objective: add up the per-step outputs, then reduce the result to a number.
    objective() = sum(sum(f0(m, Xs)))
    return gradient(objective, θ)
end
# gf1: gradient with respect to θ of the summed outputs of f1(m, Xs).
# Counterpart of gf0, differing only in the forward pass it calls.
function gf1(m, Xs, θ)
# `gradient` is not defined in this block; presumably Zygote/Flux's implicit-parameter
# form taking a zero-argument closure — confirm.
gradient(θ) do
# inner sum adds up the vector of per-step outputs returned by f1, outer sum reduces to a scalar
sum(sum(f1(m, Xs)))
end
end Results for xs = first(Xs_train); Xs = vecofmats2tensor(xs): julia> reset!(m)
julia> @benchmark f0($m, $Xs)
BenchmarkTools.Trial:
memory estimate: 1.72 GiB
allocs estimate: 54890
--------------
minimum time: 2.461 s (7.70% GC)
median time: 2.514 s (8.95% GC)
mean time: 2.511 s (8.64% GC)
maximum time: 2.556 s (9.25% GC)
--------------
samples: 3
evals/sample: 1
julia> reset!(m)
julia> @benchmark f1($m, $Xs)
BenchmarkTools.Trial:
memory estimate: 1.40 GiB
allocs estimate: 643228
--------------
minimum time: 2.222 s (7.58% GC)
median time: 2.251 s (8.09% GC)
mean time: 2.244 s (7.92% GC)
maximum time: 2.261 s (8.09% GC)
--------------
samples: 3
evals/sample: 1
julia> reset!(m)
julia> @benchmark gf0($m, $Xs, $θ)
BenchmarkTools.Trial:
memory estimate: 5.75 GiB
allocs estimate: 362721
--------------
minimum time: 8.926 s (34.64% GC)
median time: 8.926 s (34.64% GC)
mean time: 8.926 s (34.64% GC)
maximum time: 8.926 s (34.64% GC)
--------------
samples: 1
evals/sample: 1
julia> reset!(m)
julia> @benchmark gf1($m, $Xs, $θ)
BenchmarkTools.Trial:
memory estimate: 97.39 GiB
allocs estimate: 18548950
--------------
minimum time: 95.674 s (28.28% GC)
median time: 95.674 s (28.28% GC)
mean time: 95.674 s (28.28% GC)
maximum time: 95.674 s (28.28% GC)
--------------
samples: 1
evals/sample: 1 Changing from |
Changing `ŷs = map(1:maxT) do _` to `ŷs = broadcast(1:maxT) do _` did not yield any performance gain. |
Changing
also with
|
Implementations of 6 different versions of the forward pass, one for each permutation of D (input dimension), T (time duration) and B (batch size).
A smallish neural net with the following dimensions was used:
Results for xs = last(Xs_train); Xs = vecofmats2tensor(xs):
Results for xs = first(Xs_train); Xs = vecofmats2tensor(xs):
Conclusion:
The D × T × B and D × B × T orderings seem to be the most efficient ones, although there is not much difference between the versions, and as T grows all computations are dominated by garbage collection and the speed differences almost vanish.
The text was updated successfully, but these errors were encountered: