add flag for computing p(x|y) in conditional HINT #44

Open · wants to merge 4 commits into base: master
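A quick sketch of the intended use, for review purposes only: with the new `x_lane` keyword, the Y-lane log-determinant contributions are masked out, so the returned `logdet` corresponds to the conditional density p(x|y) rather than the joint p(x, y). The constructor arguments, array sizes, and the `CH.forward` call style below are illustrative assumptions, not part of this PR; exact signatures may differ by package version.

```julia
using InvertibleNetworks

# Illustrative sizes and constructor arguments only (hypothetical)
nx, ny, n_in, n_hidden, batchsize, depth = 16, 16, 2, 8, 4, 2
CH = NetworkConditionalHINT(n_in, n_hidden, depth)

X = randn(Float32, nx, ny, n_in, batchsize)
Y = randn(Float32, nx, ny, n_in, batchsize)

# Default behaviour: logdet contains X-lane and Y-lane terms (joint objective)
Zx, Zy, logdet_joint = CH.forward(X, Y)

# New flag: Y-lane logdet terms are zeroed, giving the p(x|y) objective
Zx, Zy, logdet_cond = CH.forward(X, Y; x_lane=true)
```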
20 changes: 11 additions & 9 deletions src/conditional_layers/conditional_layer_hint.jl
@@ -82,7 +82,7 @@ end
# 3D Constructor from input dimensions
ConditionalLayerHINT3D(args...; kw...) = ConditionalLayerHINT(args...; kw..., ndims=3)

function forward(X::AbstractArray{T, N}, Y::AbstractArray{T, N}, CH::ConditionalLayerHINT; logdet=nothing) where {T, N}
function forward(X::AbstractArray{T, N}, Y::AbstractArray{T, N}, CH::ConditionalLayerHINT; logdet=nothing, x_lane::Bool=false) where {T, N}
isnothing(logdet) ? logdet = (CH.logdet && ~CH.is_reversed) : logdet = logdet

# Y-lane
@@ -96,10 +96,10 @@ function forward(X::AbstractArray{T, N}, Y::AbstractArray{T, N}, CH::Conditional
# X-lane: conditional layer
logdet ? (Zx, logdet3) = CH.CL_YX.forward(Yp, X)[2:3] : Zx = CH.CL_YX.forward(Yp, X)[2]

logdet ? (return Zx, Zy, logdet1 + logdet2 + logdet3) : (return Zx, Zy)
logdet ? (return Zx, Zy, logdet1 + !x_lane*logdet2 + logdet3) : (return Zx, Zy)
end

function inverse(Zx::AbstractArray{T, N}, Zy::AbstractArray{T, N}, CH::ConditionalLayerHINT; logdet=nothing) where {T, N}
function inverse(Zx::AbstractArray{T, N}, Zy::AbstractArray{T, N}, CH::ConditionalLayerHINT; logdet=nothing, x_lane::Bool=false) where {T, N}
isnothing(logdet) ? logdet = (CH.logdet && CH.is_reversed) : logdet = logdet

# Y-lane
@@ -114,18 +114,19 @@ function inverse(Zx::AbstractArray{T, N}, Zy::AbstractArray{T, N}, CH::Condition
logdet ? (Xp, logdet3) = CH.CL_X.inverse(X; logdet=true) : Xp = CH.CL_X.inverse(X; logdet=false)
~isnothing(CH.C_X) ? (X = CH.C_X.inverse(Xp)) : (X = copy(Xp))

logdet ? (return X, Y, logdet1 + logdet2 + logdet3) : (return X, Y)
logdet ? (return X, Y, !x_lane*logdet1 + logdet2 + logdet3) : (return X, Y)
end

function backward(ΔZx::AbstractArray{T, N}, ΔZy::AbstractArray{T, N}, Zx::AbstractArray{T, N}, Zy::AbstractArray{T, N}, CH::ConditionalLayerHINT; logdet=nothing, set_grad::Bool=true) where {T, N}
function backward(ΔZx::AbstractArray{T, N}, ΔZy::AbstractArray{T, N}, Zx::AbstractArray{T, N}, Zy::AbstractArray{T, N},
CH::ConditionalLayerHINT; logdet=nothing, set_grad::Bool=true, x_lane::Bool=false) where {T, N}
isnothing(logdet) ? logdet = (CH.logdet && ~CH.is_reversed) : logdet = logdet

# Y-lane
if set_grad
ΔYp, Yp = CH.CL_Y.backward(ΔZy, Zy)
ΔYp, Yp = CH.CL_Y.backward(ΔZy, Zy; x_lane=x_lane)
else
if logdet
ΔYp, Δθ_CLY, Yp, ∇logdet_CLY = CH.CL_Y.backward(ΔZy, Zy; set_grad=set_grad)
ΔYp, Δθ_CLY, Yp, ∇logdet_CLY = CH.CL_Y.backward(ΔZy, Zy; set_grad=set_grad, x_lane=x_lane)
else
ΔYp, Δθ_CLY, Yp = CH.CL_Y.backward(ΔZy, Zy; set_grad=set_grad)
end
@@ -185,7 +186,8 @@ function backward(ΔZx::AbstractArray{T, N}, ΔZy::AbstractArray{T, N}, Zx::Abst
end
end

function backward_inv(ΔX::AbstractArray{T, N}, ΔY::AbstractArray{T, N}, X::AbstractArray{T, N}, Y::AbstractArray{T, N}, CH::ConditionalLayerHINT) where {T, N}
function backward_inv(ΔX::AbstractArray{T, N}, ΔY::AbstractArray{T, N}, X::AbstractArray{T, N},
Y::AbstractArray{T, N}, CH::ConditionalLayerHINT; x_lane::Bool=false) where {T, N}

# 1x1 Convolutions
if isnothing(CH.C_X) || isnothing(CH.C_Y)
@@ -204,7 +206,7 @@ function backward_inv(ΔX::AbstractArray{T, N}, ΔY::AbstractArray{T, N}, X::Abs
ΔYp += ΔYp_

# Y-lane
ΔZy, Zy = backward_inv(ΔYp, Yp, CH.CL_Y)
ΔZy, Zy = backward_inv(ΔYp, Yp, CH.CL_Y; x_lane=x_lane)

return ΔZx, ΔZy, Zx, Zy
end
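Reviewer note, not part of the diff: throughout these changes the pattern `!x_lane * logdet_term` relies on a `Bool` acting as 0 or 1 in multiplication, so the Y-lane logdet terms simply drop out of the sum when `x_lane=true`. A minimal, self-contained illustration with dummy values:

```julia
# Dummy logdet values, for illustration only
logdet1, logdet2, logdet3 = 1.5f0, 0.7f0, 2.0f0

x_lane = false
logdet1 + !x_lane*logdet2 + logdet3   # 4.2f0: all terms kept (joint objective)

x_lane = true
logdet1 + !x_lane*logdet2 + logdet3   # 3.5f0: Y-lane term removed (p(x|y) objective)
```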
8 changes: 4 additions & 4 deletions src/layers/invertible_layer_actnorm.jl
@@ -97,7 +97,7 @@ function inverse(Y::AbstractArray{T, N}, AN::ActNorm; logdet=nothing) where {T,
end

# 2-3D Backward pass: Input (ΔY, Y), Output (ΔY, Y)
function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, AN::ActNorm; set_grad::Bool = true) where {T, N}
function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, AN::ActNorm; set_grad::Bool = true, x_lane::Bool=false) where {T, N}
inds = [i!=(N-1) ? 1 : (:) for i=1:N]
dims = collect(1:N-1); dims[end] +=1
nn = size(ΔY)[1:N-2]
@@ -106,7 +106,7 @@ function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, AN::ActNorm;
ΔX = ΔY .* reshape(AN.s.data, inds...)
Δs = sum(ΔY .* X, dims=dims)[inds...]
if AN.logdet
set_grad ? (Δs -= logdet_backward(nn..., AN.s)) : (Δs_ = logdet_backward(nn..., AN.s))
set_grad ? (Δs -= !x_lane*logdet_backward(nn..., AN.s)) : (Δs_ = !x_lane*logdet_backward(nn..., AN.s))
end
Δb = sum(ΔY, dims=dims)[inds...]
if set_grad
@@ -124,7 +124,7 @@ end

## Reverse-layer functions
# 2-3D Backward pass (inverse): Input (ΔX, X), Output (ΔX, X)
function backward_inv(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, AN::ActNorm; set_grad::Bool = true) where {T, N}
function backward_inv(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, AN::ActNorm; set_grad::Bool = true, x_lane::Bool=false) where {T, N}
inds = [i!=(N-1) ? 1 : (:) for i=1:N]
dims = collect(1:N-1); dims[end] +=1
nn = size(ΔX)[1:N-2]
Expand All @@ -133,7 +133,7 @@ function backward_inv(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, AN::ActN
ΔY = ΔX ./ reshape(AN.s.data, inds...)
Δs = -sum(ΔX .* X ./ reshape(AN.s.data, inds...), dims=dims)[inds...]
if AN.logdet
set_grad ? (Δs += logdet_backward(nn..., AN.s)) : (∇logdet = -logdet_backward(nn..., AN.s))
set_grad ? (Δs += !x_lane*logdet_backward(nn..., AN.s)) : (∇logdet = !x_lane*(-logdet_backward(nn..., AN.s)))
end
Δb = -sum(ΔX ./ reshape(AN.s.data, inds...), dims=dims)[inds...]
if set_grad
14 changes: 7 additions & 7 deletions src/layers/invertible_layer_basic.jl
@@ -119,7 +119,7 @@ function inverse(Y1::AbstractArray{T, N}, Y2::AbstractArray{T, N}, L::CouplingLa
end

# 2D/3D Backward pass: Input (ΔY, Y), Output (ΔX, X)
function backward(ΔY1::AbstractArray{T, N}, ΔY2::AbstractArray{T, N}, Y1::AbstractArray{T, N}, Y2::AbstractArray{T, N}, L::CouplingLayerBasic; set_grad::Bool=true) where {T, N}
function backward(ΔY1::AbstractArray{T, N}, ΔY2::AbstractArray{T, N}, Y1::AbstractArray{T, N}, Y2::AbstractArray{T, N}, L::CouplingLayerBasic; set_grad::Bool=true, x_lane::Bool=false) where {T, N}

# Recompute forward state
X1, X2, S = inverse(Y1, Y2, L; save=true, logdet=false)
@@ -128,15 +128,15 @@ function backward(ΔY1::AbstractArray{T, N}, ΔY2::AbstractArray{T, N}, Y1::Abst
ΔT = copy(ΔY2)
ΔS = ΔY2 .* X2
if L.logdet
set_grad && (ΔS -= coupling_logdet_backward(S))
set_grad && (ΔS -= !x_lane * coupling_logdet_backward(S))
end
ΔX2 = ΔY2 .* S
if set_grad
ΔX1 = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), ΔT), X1) + ΔY1
else
ΔX1, Δθ = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), ΔT), X1; set_grad=set_grad)
if L.logdet
_, ∇logdet = L.RB.backward(tensor_cat(L.activation.backward(coupling_logdet_backward(S), S), 0 .*ΔT), X1; set_grad=set_grad)
_, ∇logdet = L.RB.backward(tensor_cat(L.activation.backward(!x_lane * coupling_logdet_backward(S), S), 0f0.*ΔT), X1; set_grad=set_grad)
end
ΔX1 += ΔY1
end
@@ -149,7 +149,7 @@ end
end

# 2D/3D Reverse backward pass: Input (ΔX, X), Output (ΔY, Y)
function backward_inv(ΔX1::AbstractArray{T, N}, ΔX2::AbstractArray{T, N}, X1::AbstractArray{T, N}, X2::AbstractArray{T, N}, L::CouplingLayerBasic; set_grad::Bool=true) where {T, N}
function backward_inv(ΔX1::AbstractArray{T, N}, ΔX2::AbstractArray{T, N}, X1::AbstractArray{T, N}, X2::AbstractArray{T, N}, L::CouplingLayerBasic; set_grad::Bool=true, x_lane::Bool=false) where {T, N}

# Recompute inverse state
Y1, Y2, S = forward(X1, X2, L; save=true, logdet=false)
Expand All @@ -158,7 +158,7 @@ function backward_inv(ΔX1::AbstractArray{T, N}, ΔX2::AbstractArray{T, N}, X1::
ΔT = -ΔX2 ./ S
ΔS = X2 .* ΔT
if L.logdet == true
set_grad ? (ΔS += coupling_logdet_backward(S)) : (∇logdet = -coupling_logdet_backward(S))
set_grad ? (ΔS += !x_lane*coupling_logdet_backward(S)) : (∇logdet = !x_lane*(-coupling_logdet_backward(S)))
end
if set_grad
ΔY1 = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), ΔT), Y1) + ΔX1
@@ -195,7 +195,7 @@ function jacobian(ΔX1::AbstractArray{T, N}, ΔX2::AbstractArray{T, N}, Δθ::Ab
# Gauss-Newton approximation of logdet terms
JΔθ = tensor_split(L.RB.jacobian(zeros(Float32, size(ΔX1)), Δθ, X1)[1])[1]
GNΔθ = -L.RB.adjointJacobian(tensor_cat(L.activation.backward(JΔθ, S), zeros(Float32, size(S))), X1)[2]

save ? (return ΔX1, ΔY2, X1, Y2, coupling_logdet_forward(S), GNΔθ, S) : (return ΔX1, ΔY2, X1, Y2, coupling_logdet_forward(S), GNΔθ)
else
save ? (return ΔX1, ΔY2, X1, Y2, S) : (return ΔX1, ΔY2, X1, Y2)
@@ -223,4 +223,4 @@ get_params(L::CouplingLayerBasic) = get_params(L.RB)
function tag_as_reversed!(L::CouplingLayerBasic, tag::Bool)
L.is_reversed = tag
return L
end
end
28 changes: 14 additions & 14 deletions src/layers/invertible_layer_hint.jl
@@ -201,7 +201,7 @@ function inverse(Y::AbstractArray{T, N} , H::CouplingLayerHINT; scale=1, permute
end

# Input are two tensors ΔY, Y
function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, H::CouplingLayerHINT; scale=1, permute=nothing, set_grad::Bool=true) where {T, N}
function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, H::CouplingLayerHINT; scale=1, permute=nothing, set_grad::Bool=true, x_lane::Bool=false) where {T, N}
isnothing(permute) ? permute = H.permute : permute = permute

# Initializing output parameter array
@@ -231,16 +231,16 @@ function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, H::CouplingL
# HINT coupling
if recursive
if set_grad
ΔXa, Xa = backward(ΔYa, Ya, H; scale=scale+1, permute="none")
ΔXa_temp, ΔXb_temp, X_temp = H.CL[scale].backward(ΔXa.*0, ΔYb, Xa, Yb)[[1,2,4]]
ΔXb, Xb = backward(ΔXb_temp, X_temp, H; scale=scale+1, permute="none")
ΔXa, Xa = backward(ΔYa, Ya, H; scale=scale+1, permute="none", x_lane=x_lane)
ΔXa_temp, ΔXb_temp, X_temp = H.CL[scale].backward(ΔXa.*0f0, ΔYb, Xa, Yb; x_lane=x_lane)[[1,2,4]]
ΔXb, Xb = backward(ΔXb_temp, X_temp, H; scale=scale+1, permute="none", x_lane=x_lane)
else
if H.logdet
ΔXa, Δθa, Xa, ∇logdet_a = backward(ΔYa, Ya, H; scale=scale+1, permute="none", set_grad=set_grad)
ΔXa_temp, ΔXb_temp, Δθ_scale, _, X_temp, ∇logdet_scale = H.CL[scale].backward(ΔXa.*0, ΔYb, Xa, Yb; set_grad=set_grad)
ΔXb, Δθb, Xb, ∇logdet_b = backward(ΔXb_temp, X_temp, H; scale=scale+1, permute="none", set_grad=set_grad)
∇logdet[1:5] .= ∇logdet_scale
∇logdet[6:5+length(∇logdet_a)] .= ∇logdet_a+∇logdet_b
∇logdet[1:5] .= !x_lane*∇logdet_scale
∇logdet[6:5+length(∇logdet_a)] .= !x_lane*∇logdet_a+!x_lane*∇logdet_b
else
ΔXa, Δθa, Xa = backward(ΔYa, Ya, H; scale=scale+1, permute="none", set_grad=set_grad)
ΔXa_temp, ΔXb_temp, Δθ_scale, _, X_temp = H.CL[scale].backward(ΔXa.*0, ΔYb, Xa, Yb; set_grad=set_grad)
@@ -254,11 +254,11 @@ function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, H::CouplingL
Xa = copy(Ya)
ΔXa = copy(ΔYa)
if set_grad
ΔXa_, ΔXb, Xb = H.CL[scale].backward(ΔYa.*0, ΔYb, Ya, Yb)[[1,2,4]]
ΔXa_, ΔXb, Xb = H.CL[scale].backward(ΔYa.*0f0, ΔYb, Ya, Yb; x_lane=x_lane)[[1,2,4]]
else
if H.logdet
ΔXa_, ΔXb, Δθ_scale, _, Xb, ∇logdet_scale = H.CL[scale].backward(ΔYa.*0, ΔYb, Ya, Yb; set_grad=set_grad)
∇logdet[1:5] .= ∇logdet_scale
ΔXa_, ΔXb, Δθ_scale, _, Xb, ∇logdet_scale = H.CL[scale].backward(ΔYa.*0f0, ΔYb, Ya, Yb; set_grad=set_grad)
∇logdet[1:5] .= !x_lane*∇logdet_scale
else
ΔXa_, ΔXb, Δθ_scale, _, Xb = H.CL[scale].backward(ΔYa.*0, ΔYb, Ya, Yb; set_grad=set_grad)
end
@@ -299,7 +299,7 @@ function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, H::CouplingL
end

# Input are two tensors ΔX, X
function backward_inv(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, H::CouplingLayerHINT; scale=1, permute=nothing) where {T, N}
function backward_inv(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, H::CouplingLayerHINT; scale=1, permute=nothing, x_lane::Bool=false) where {T, N}
isnothing(permute) ? permute = H.permute : permute = permute

# Permutation
@@ -315,13 +315,13 @@ function backward_inv(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, H::Coupl

# Coupling layer backprop
if recursive
ΔY_temp, Y_temp = backward_inv(ΔXb, Xb, H; scale=scale+1, permute="none")
ΔYa_temp, ΔYb, Yb = backward_inv(0 .*ΔXa, ΔY_temp, Xa, Y_temp, H.CL[scale])[[1,2,4]]
ΔYa, Ya = backward_inv(ΔXa+ΔYa_temp, Xa, H; scale=scale+1, permute="none")
ΔY_temp, Y_temp = backward_inv(ΔXb, Xb, H; scale=scale+1, permute="none", x_lane=x_lane)
ΔYa_temp, ΔYb, Yb = backward_inv(0f0.*ΔXa, ΔY_temp, Xa, Y_temp, H.CL[scale]; x_lane=x_lane)[[1,2,4]]
ΔYa, Ya = backward_inv(ΔXa+ΔYa_temp, Xa, H; scale=scale+1, permute="none", x_lane=x_lane)
else
ΔYa = copy(ΔXa)
Ya = copy(Xa)
ΔYa_temp, ΔYb, Yb = backward_inv(0 .*ΔYa, ΔXb, Xa, Xb, H.CL[scale])[[1,2,4]]
ΔYa_temp, ΔYb, Yb = backward_inv(0f0.*ΔYa, ΔXb, Xa, Xb, H.CL[scale]; x_lane=x_lane)[[1,2,4]]
ΔYa += ΔYa_temp
end
ΔY = tensor_cat(ΔYa, ΔYb)
32 changes: 17 additions & 15 deletions src/networks/invertible_network_conditional_hint.jl
@@ -78,52 +78,53 @@ end
NetworkConditionalHINT3D(args...;kw...) = NetworkConditionalHINT(args...; kw..., ndims=3)

# Forward pass and compute logdet
function forward(X::AbstractArray{T, N}, Y::AbstractArray{T, N}, CH::NetworkConditionalHINT; logdet=nothing) where {T, N}
function forward(X::AbstractArray{T, N}, Y::AbstractArray{T, N}, CH::NetworkConditionalHINT; logdet=nothing, x_lane=false) where {T, N}
isnothing(logdet) ? logdet = (CH.logdet && ~CH.is_reversed) : logdet = logdet

depth = length(CH.CL)
logdet_ = 0
for j=1:depth
logdet ? (X_, logdet1) = CH.AN_X[j].forward(X) : X_ = CH.AN_X[j].forward(X)
logdet ? (Y_, logdet2) = CH.AN_Y[j].forward(Y) : Y_ = CH.AN_Y[j].forward(Y)
logdet ? (X, Y, logdet3) = CH.CL[j].forward(X_, Y_) : (X, Y) = CH.CL[j].forward(X_, Y_)
logdet && (logdet_ += (logdet1 + logdet2 + logdet3))
logdet ? (X, Y, logdet3) = CH.CL[j].forward(X_, Y_; x_lane=x_lane) : (X, Y) = CH.CL[j].forward(X_, Y_)
logdet && (logdet_ += (logdet1 + !x_lane*logdet2 + logdet3))
end
logdet ? (return X, Y, logdet_) : (return X, Y)
end

# Inverse pass and compute gradients
function inverse(Zx::AbstractArray{T, N}, Zy::AbstractArray{T, N}, CH::NetworkConditionalHINT; logdet=nothing) where {T, N}
function inverse(Zx::AbstractArray{T, N}, Zy::AbstractArray{T, N}, CH::NetworkConditionalHINT; logdet=nothing, x_lane=false) where {T, N}
isnothing(logdet) ? logdet = (CH.logdet && CH.is_reversed) : logdet = logdet

depth = length(CH.CL)
logdet_ = 0
for j=depth:-1:1
logdet ? (Zx_, Zy_, logdet1) = CH.CL[j].inverse(Zx, Zy; logdet=true) : (Zx_, Zy_) = CH.CL[j].inverse(Zx, Zy; logdet=false)
logdet ? (Zx_, Zy_, logdet1) = CH.CL[j].inverse(Zx, Zy; logdet=true, x_lane=x_lane) : (Zx_, Zy_) = CH.CL[j].inverse(Zx, Zy; logdet=false)
logdet ? (Zy, logdet2) = CH.AN_Y[j].inverse(Zy_; logdet=true) : Zy = CH.AN_Y[j].inverse(Zy_; logdet=false)
logdet ? (Zx, logdet3) = CH.AN_X[j].inverse(Zx_; logdet=true) : Zx = CH.AN_X[j].inverse(Zx_; logdet=false)
logdet && (logdet_ += (logdet1 + logdet2 + logdet3))
logdet && (logdet_ += (logdet1 + !x_lane*logdet2 + logdet3))
end
logdet ? (return Zx, Zy, logdet_) : (return Zx, Zy)
end

# Backward pass and compute gradients
function backward(ΔZx::AbstractArray{T, N}, ΔZy::AbstractArray{T, N}, Zx::AbstractArray{T, N}, Zy::AbstractArray{T, N}, CH::NetworkConditionalHINT; set_grad::Bool=true) where {T, N}
function backward(ΔZx::AbstractArray{T, N}, ΔZy::AbstractArray{T, N}, Zx::AbstractArray{T, N},
Zy::AbstractArray{T, N}, CH::NetworkConditionalHINT; set_grad::Bool=true, x_lane::Bool=false) where {T, N}
depth = length(CH.CL)
if ~set_grad
Δθ = Array{Parameter, 1}(undef, 0)
CH.logdet && (∇logdet = Array{Parameter, 1}(undef, 0))
end
for j=depth:-1:1
if set_grad
ΔZx_, ΔZy_, Zx_, Zy_ = CH.CL[j].backward(ΔZx, ΔZy, Zx, Zy)
ΔZx_, ΔZy_, Zx_, Zy_ = CH.CL[j].backward(ΔZx, ΔZy, Zx, Zy; x_lane=x_lane)
ΔZx, Zx = CH.AN_X[j].backward(ΔZx_, Zx_)
ΔZy, Zy = CH.AN_Y[j].backward(ΔZy_, Zy_)
ΔZy, Zy = CH.AN_Y[j].backward(ΔZy_, Zy_; x_lane=x_lane)
else
if CH.logdet
ΔZx_, ΔZy_, Δθcl, Zx_, Zy_, ∇logdetcl = CH.CL[j].backward(ΔZx, ΔZy, Zx, Zy; set_grad=set_grad)
ΔZx_, ΔZy_, Δθcl, Zx_, Zy_, ∇logdetcl = CH.CL[j].backward(ΔZx, ΔZy, Zx, Zy; set_grad=set_grad, x_lane=x_lane)
ΔZx, Δθx, Zx, ∇logdetx = CH.AN_X[j].backward(ΔZx_, Zx_; set_grad=set_grad)
ΔZy, Δθy, Zy, ∇logdety = CH.AN_Y[j].backward(ΔZy_, Zy_; set_grad=set_grad)
ΔZy, Δθy, Zy, ∇logdety = CH.AN_Y[j].backward(ΔZy_, Zy_; set_grad=set_grad, x_lane=x_lane)
∇logdet = cat(∇logdetx, ∇logdety, ∇logdetcl, ∇logdet; dims=1)
else
ΔZx_, ΔZy_, Δθcl, Zx_, Zy_ = CH.CL[j].backward(ΔZx, ΔZy, Zx, Zy; set_grad=set_grad)
@@ -141,12 +142,13 @@ function backward(ΔZx::AbstractArray{T, N}, ΔZy::AbstractArray{T, N}, Zx::Abst
end

# Backward reverse pass and compute gradients
function backward_inv(ΔX::AbstractArray{T, N}, ΔY::AbstractArray{T, N}, X::AbstractArray{T, N}, Y::AbstractArray{T, N}, CH::NetworkConditionalHINT) where {T, N}
function backward_inv(ΔX::AbstractArray{T, N}, ΔY::AbstractArray{T, N}, X::AbstractArray{T, N},
Y::AbstractArray{T, N}, CH::NetworkConditionalHINT; x_lane::Bool=false) where {T, N}
depth = length(CH.CL)
for j=1:depth
ΔX_, X_ = backward_inv(ΔX, X, CH.AN_X[j])
ΔY_, Y_ = backward_inv(ΔY, Y, CH.AN_Y[j])
ΔX, ΔY, X, Y = backward_inv(ΔX_, ΔY_, X_, Y_, CH.CL[j])
ΔY_, Y_ = backward_inv(ΔY, Y, CH.AN_Y[j]; x_lane=x_lane)
ΔX, ΔY, X, Y = backward_inv(ΔX_, ΔY_, X_, Y_, CH.CL[j]; x_lane=x_lane)
end
return ΔX, ΔY, X, Y
end
@@ -233,4 +235,4 @@ function tag_as_reversed!(CH::NetworkConditionalHINT, tag::Bool)
tag_as_reversed!(CH.CL[j], tag)
end
return CH
end
end
10 changes: 9 additions & 1 deletion src/utils/parameter.jl
@@ -134,6 +134,14 @@ function /(p1::T, p2::Parameter) where {T<:Real}
return Parameter(p1/p2.data)
end

function *(p1::Parameter, p2::Bool)
return Parameter(p1.data*p2)
end

function *(p1::Bool, p2::Parameter)
return p2*p1
end

# Shape manipulation

par2vec(x::Parameter) = vec(x.data), size(x.data)
@@ -157,4 +165,4 @@ function vec2par(x::AbstractArray{T, 1}, s::Array{Any, 1}) where T
idx_i += prod(s[i])
end
return xpar
end
end
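Reviewer note: the two new `*` methods above exist so that the same `!x_lane * ...` masking also works when logdet gradients are stored as `Parameter`s (e.g. `∇logdet[1:5] .= !x_lane*∇logdet_scale` in `invertible_layer_hint.jl`). A stripped-down sketch of the behaviour, using a stand-in `Parameter` type rather than the package's own:

```julia
# Stand-in Parameter type for illustration; the real one lives in src/utils/parameter.jl
struct Parameter
    data
end
Base.:*(p::Parameter, b::Bool) = Parameter(p.data * b)
Base.:*(b::Bool, p::Parameter) = p * b

g = Parameter(Float32[0.3, -1.2])

!true  * g    # Parameter(Float32[0.0, -0.0]): gradient masked out (x_lane = true)
!false * g    # Parameter(Float32[0.3, -1.2]): gradient kept (x_lane = false)
```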