reverse glow for MAP #65

Open · wants to merge 5 commits into master
19 changes: 8 additions & 11 deletions src/layers/invertible_layer_actnorm.jl
@@ -57,8 +57,7 @@ function ActNorm(k; logdet=false)
end

# 2-3D Forward pass: Input X, Output Y
function forward(X::AbstractArray{T, N}, AN::ActNorm; logdet=nothing) where {T, N}
isnothing(logdet) ? logdet = (AN.logdet && ~AN.is_reversed) : logdet = logdet
function forward(X::AbstractArray{T, N}, AN::ActNorm;) where {T, N}
inds = [i!=(N-1) ? 1 : Colon() for i=1:N]
dims = collect(1:N-1); dims[end] +=1

@@ -73,12 +72,11 @@ function forward(X::AbstractArray{T, N}, AN::ActNorm; logdet=nothing) where {T,
Y = X .* reshape(AN.s.data, inds...) .+ reshape(AN.b.data, inds...)

# If logdet true, return as second output argument
logdet ? (return Y, logdet_forward(size(X)[1:N-2]..., AN.s)) : (return Y)
AN.logdet ? (return Y, logdet_forward(size(X)[1:N-2]..., AN.s)) : (return Y)
end

# 2-3D Inverse pass: Input Y, Output X
function inverse(Y::AbstractArray{T, N}, AN::ActNorm; logdet=nothing) where {T, N}
isnothing(logdet) ? logdet = (AN.logdet && AN.is_reversed) : logdet = logdet
function inverse(Y::AbstractArray{T, N}, AN::ActNorm;) where {T, N}
inds = [i!=(N-1) ? 1 : Colon() for i=1:N]
dims = collect(1:N-1); dims[end] +=1

@@ -93,7 +91,7 @@ function inverse(Y::AbstractArray{T, N}, AN::ActNorm; logdet=nothing) where {T,
X = (Y .- reshape(AN.b.data, inds...)) ./ reshape(AN.s.data, inds...)

# If logdet true, return as second output argument
logdet ? (return X, -logdet_forward(size(Y)[1:N-2]..., AN.s)) : (return X)
AN.logdet ? (return X, -logdet_forward(size(Y)[1:N-2]..., AN.s)) : (return X)
end

# 2-3D Backward pass: Input (ΔY, Y), Output (ΔX, X)
@@ -102,7 +100,7 @@ function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, AN::ActNorm;
dims = collect(1:N-1); dims[end] +=1
nn = size(ΔY)[1:N-2]

X = inverse(Y, AN; logdet=false)
AN.logdet ? (X, logdet_i) = inverse(Y, AN;) : X = inverse(Y, AN;)
ΔX = ΔY .* reshape(AN.s.data, inds...)
Δs = sum(ΔY .* X, dims=dims)[inds...]
if AN.logdet
@@ -129,7 +127,7 @@ function backward_inv(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, AN::ActN
dims = collect(1:N-1); dims[end] +=1
nn = size(ΔX)[1:N-2]

Y = forward(X, AN; logdet=false)
AN.logdet ? (Y, logdet_i) = forward(X, AN;) : Y = forward(X, AN;)
ΔY = ΔX ./ reshape(AN.s.data, inds...)
Δs = -sum(ΔX .* X ./ reshape(AN.s.data, inds...), dims=dims)[inds...]
if AN.logdet
@@ -152,20 +150,19 @@ end
## Jacobian-related functions
# 2-3D
function jacobian(ΔX::AbstractArray{T, N}, Δθ::AbstractArray{Parameter, 1}, X::AbstractArray{T, N}, AN::ActNorm; logdet=nothing) where {T, N}
isnothing(logdet) ? logdet = (AN.logdet && ~AN.is_reversed) : logdet = logdet
inds = [i!=(N-1) ? 1 : Colon() for i=1:N]
nn = size(ΔX)[1:N-2]
Δs = Δθ[1].data
Δb = Δθ[2].data

# Forward evaluation
logdet ? (Y, lgdet) = forward(X, AN; logdet=logdet) : Y = forward(X, AN; logdet=logdet)
AN.logdet ? (Y, lgdet) = forward(X, AN;) : Y = forward(X, AN;)

# Jacobian evaluation
ΔY = ΔX .* reshape(AN.s.data, inds...) .+ X .* reshape(Δs, inds...) .+ reshape(Δb, inds...)

# Hessian evaluation of logdet terms
if logdet
if AN.logdet
nx, ny, _, _ = size(X)
HlogΔθ = [Parameter(logdet_hessian(nn..., AN.s).*Δs), Parameter(zeros(Float32, size(Δb)))]
return ΔY, Y, lgdet, HlogΔθ
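With the `logdet` keyword removed, whether ActNorm returns a log-determinant now depends only on the layer's own `logdet` flag, and `inverse` returns the negated value. A minimal usage sketch under these changes (channel count and array shapes are illustrative, not taken from this PR):

```julia
using InvertibleNetworks

AN = ActNorm(4; logdet=true)       # 4 channels, chosen for illustration
X  = randn(Float32, 16, 16, 4, 2)

Y,  lgdet_f = AN.forward(X)        # two outputs because AN.logdet == true
X_, lgdet_i = AN.inverse(Y)        # inverse now also returns -logdet
```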
64 changes: 49 additions & 15 deletions src/layers/invertible_layer_glow.jl
@@ -91,20 +91,24 @@ end
CouplingLayerGlow3D(args...;kw...) = CouplingLayerGlow(args...; kw..., ndims=3)

# Forward pass: Input X, Output Y
function forward(X::AbstractArray{T, N}, L::CouplingLayerGlow) where {T,N}
function forward(X::AbstractArray{T, N}, L::CouplingLayerGlow; save=false) where {T,N}

X_ = L.C.forward(X)
X1, X2 = tensor_split(X_)

Y2 = copy(X2)
logS_T = L.RB.forward(X2)
logSm, Tm = tensor_split(logS_T)
Sm = L.activation.forward(logSm)
Y1 = Sm.*X1 + Tm

Y = tensor_cat(Y1, Y2)
Y = tensor_cat(Y1, X2)

if L.logdet
save ? (return Y, Y1, X2, Sm, glow_logdet_forward(Sm)) : (return Y, glow_logdet_forward(Sm))
else
save ? (return Y, Y1, X2, Sm) : (return Y)
end

L.logdet == true ? (return Y, glow_logdet_forward(Sm)) : (return Y)
end
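A hedged sketch of the extended forward signature (constructor arguments and shapes are made up for illustration); with `save=true` the internal coupling state is returned alongside the output:

```julia
using InvertibleNetworks

CL = CouplingLayerGlow(4, 16; logdet=true)   # 4 input channels, 16 hidden channels (illustrative)
X  = randn(Float32, 16, 16, 4, 2)

Y, lgdet = forward(X, CL)                          # save=false (default): output and logdet
Y, Y1, X2, Sm, lgdet = forward(X, CL; save=true)   # save=true: also the internal state
```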

# Inverse pass: Input Y, Output X
@@ -120,7 +124,11 @@ function inverse(Y::AbstractArray{T, N}, L::CouplingLayerGlow; save=false) where
X_ = tensor_cat(X1, X2)
X = L.C.inverse(X_)

save == true ? (return X, X1, X2, Sm) : (return X)
if L.logdet
save ? (return X, X1, X2, Sm, -glow_logdet_forward(Sm)) : (return X, -glow_logdet_forward(Sm))
else
save ? (return X, X1, X2, Sm) : (return X)
end
end

# Backward pass: Input (ΔY, Y), Output (ΔX, X)
@@ -160,13 +168,37 @@ function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, L::CouplingL
end
end

# 2D/3D Reverse backward pass: Input (ΔX, X), Output (ΔY, Y)
function backward_inv(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, L::CouplingLayerGlow; set_grad::Bool=true) where {T, N}

ΔX, X = L.C.forward((ΔX, X))
X1, X2 = tensor_split(X)
ΔX1, ΔX2 = tensor_split(ΔX)

# Recompute forward state
logS_T = L.RB.forward(X2)
logSm, Tm = tensor_split(logS_T)
Sm = L.activation.forward(logSm)
Y1 = Sm.*X1 + Tm

# Backpropagate residual
ΔT = -ΔX1 ./ Sm
ΔS = X1 .* ΔT
if L.logdet == true
ΔS += coupling_logdet_backward(Sm)
end

ΔY2 = L.RB.backward(tensor_cat(L.activation.backward(ΔS, Sm), ΔT), X2) + ΔX2
ΔY1 = -ΔT

ΔY = tensor_cat(ΔY1, ΔY2)
Y = tensor_cat(Y1, X2)

# Get dimensions
k = Int(L.C.k/2)
return ΔY, Y
end
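A hedged sketch of how this reverse backward pass might be called when the layer is applied in the inverse direction, for example inside a MAP objective; the loss, shapes, and the `logdet=false` flag are assumptions made for illustration:

```julia
using InvertibleNetworks

CL = CouplingLayerGlow(4, 16; logdet=false)
Y  = randn(Float32, 16, 16, 4, 2)

X  = inverse(Y, CL)                 # run the layer in the reverse direction
ΔX = copy(X)                        # e.g. gradient of 0.5*norm(X)^2 with respect to X
ΔY, Y_re = backward_inv(ΔX, X, CL)  # gradient with respect to Y, plus the recomputed Y
```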

## Jacobian-related functions
function jacobian(ΔX::AbstractArray{T, N}, Δθ::Array{Parameter, 1}, X, L::CouplingLayerGlow) where {T,N}

ΔX_, X_ = L.C.jacobian(ΔX, Δθ[1:3], X)
X1, X2 = tensor_split(X_)
@@ -175,17 +207,19 @@ function jacobian(ΔX::AbstractArray{T, N}, Δθ::Array{Parameter, 1}, X, L::Cou
Y2 = copy(X2)
ΔY2 = copy(ΔX2)
ΔlogS_T, logS_T = L.RB.jacobian(ΔX2, Δθ[4:end], X2)
Sm = L.activation.forward(logS_T[:,:,1:k,:])
ΔS = L.activation.backward(ΔlogS_T[:,:,1:k,:], nothing;x=logS_T[:,:,1:k,:])
Tm = logS_T[:, :, k+1:end, :]
ΔT = ΔlogS_T[:, :, k+1:end, :]
logSm, Tm = tensor_split(logS_T)
ΔlogSm, ΔT = tensor_split(ΔlogS_T)

Sm = L.activation.forward(logSm)
ΔS = L.activation.backward(ΔlogSm, nothing;x=logSm)
Y1 = Sm.*X1 + Tm
ΔY1 = ΔS.*X1 + Sm.*ΔX1 + ΔT
Y = tensor_cat(Y1, Y2)
ΔY = tensor_cat(ΔY1, ΔY2)

# Gauss-Newton approximation of logdet terms
JΔθ = L.RB.jacobian(cuzeros(ΔX2, size(ΔX2)), Δθ[4:end], X2)[1][:, :, 1:k, :]
JΔθ = L.RB.jacobian(cuzeros(ΔX2, size(ΔX2)), Δθ[4:end], X2)[1]
JΔθ = tensor_split(JΔθ)[1]
GNΔθ = cat(0f0*Δθ[1:3], -L.RB.adjointJacobian(tensor_cat(L.activation.backward(JΔθ, Sm), zeros(Float32, size(Sm))), X2)[2]; dims=1)

L.logdet ? (return ΔY, Y, glow_logdet_forward(Sm), GNΔθ) : (return ΔY, Y)
@@ -195,6 +229,6 @@ function adjointJacobian(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, L::Co
return backward(ΔY, Y, L; set_grad=false)
end

# Logdet (correct?)
# Logdet
glow_logdet_forward(S) = sum(log.(abs.(S))) / size(S)[end]
glow_logdet_backward(S) = 1f0./ S / size(S)[end]
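A tiny numeric sketch of the helper defined above, assuming the batch dimension is last so the sum of log-scales is averaged over samples:

```julia
S = reshape(Float32[1, 2, 4, 8], 1, 1, 2, 2)   # two samples with two channel entries each
# sum(log.(abs.(S))) = log(1) + log(2) + log(4) + log(8) = 6*log(2); dividing by the
# batch size 2 gives 3*log(2) ≈ 2.079
glow_logdet_forward(S)
```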
71 changes: 52 additions & 19 deletions src/networks/invertible_network_glow.jl
@@ -67,12 +67,13 @@ struct NetworkGlow <: InvertibleNetwork
K::Int64
squeezer::Squeezer
split_scales::Bool
logdet::Bool
end

@Flux.functor NetworkGlow

# Constructor
function NetworkGlow(n_in, n_hidden, L, K; freeze_conv=false, split_scales=false, k1=3, k2=1, p1=1, p2=0, s1=1, s2=1, ndims=2, squeezer::Squeezer=ShuffleLayer(), activation::ActivationFunction=SigmoidLayer())
function NetworkGlow(n_in, n_hidden, L, K; logdet=true,freeze_conv=false, split_scales=false, k1=3, k2=1, p1=1, p2=0, s1=1, s2=1, ndims=2, squeezer::Squeezer=ShuffleLayer(), activation::ActivationFunction=SigmoidLayer())
AN = Array{ActNorm}(undef, L, K) # activation normalization
CL = Array{CouplingLayerGlow}(undef, L, K) # coupling layers w/ 1x1 convolution and residual block

@@ -87,63 +88,65 @@ function NetworkGlow(n_in, n_hidden, L, K; freeze_conv=false, split_scales=false
for i=1:L
n_in *= channel_factor # squeeze if split_scales is turned on
for j=1:K
AN[i, j] = ActNorm(n_in; logdet=true)
CL[i, j] = CouplingLayerGlow(n_in, n_hidden; freeze_conv=freeze_conv, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, logdet=true, activation=activation, ndims=ndims)
AN[i, j] = ActNorm(n_in; logdet=logdet)
CL[i, j] = CouplingLayerGlow(n_in, n_hidden; logdet=logdet, freeze_conv=freeze_conv, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, activation=activation, ndims=ndims)
end
(i < L && split_scales) && (n_in = Int64(n_in/2)) # split
end

return NetworkGlow(AN, CL, Z_dims, L, K, squeezer, split_scales)
return NetworkGlow(AN, CL, Z_dims, L, K, squeezer, split_scales, logdet)
end

NetworkGlow3D(args; kw...) = NetworkGlow(args...; kw..., ndims=3)

# Forward pass and compute logdet
function forward(X::AbstractArray{T, N}, G::NetworkGlow) where {T, N}
G.split_scales && (Z_save = array_of_array(X, G.L-1))

orig_shape = size(X)
logdet = 0
for i=1:G.L
(G.split_scales) && (X = G.squeezer.forward(X))
for j=1:G.K
X, logdet1 = G.AN[i, j].forward(X)
X, logdet2 = G.CL[i, j].forward(X)
logdet += (logdet1 + logdet2)
G.logdet ? (X, logdet1) = G.AN[i, j].forward(X) : X = G.AN[i, j].forward(X)
G.logdet ? (X, logdet2) = G.CL[i, j].forward(X) : X = G.CL[i, j].forward(X)
G.logdet && (logdet += (logdet1 + logdet2))
end
if G.split_scales && i < G.L # don't split after last iteration
X, Z = tensor_split(X)
Z_save[i] = Z
G.Z_dims[i] = collect(size(Z))
end
end
G.split_scales && (X = cat_states(Z_save, X))
return X, logdet
G.split_scales && (X = reshape(cat_states(Z_save, X),orig_shape))
G.logdet ? (return X, logdet) : (return X)
end
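A hedged usage sketch of the new `logdet` toggle at the network level (constructor arguments and shapes are illustrative only):

```julia
using InvertibleNetworks

X = randn(Float32, 16, 16, 4, 2)

G1 = NetworkGlow(4, 32, 2, 3)                      # logdet=true by default
Z,  lgdet_f = G1.forward(X)
Xr, lgdet_i = G1.inverse(Z)                        # inverse now also returns a logdet

G2 = NetworkGlow(4, 32, 2, 3; logdet=false, split_scales=true)
Z2 = G2.forward(X)                                 # single output when logdet is off
X2 = G2.inverse(Z2)
```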

# Inverse pass
function inverse(X::AbstractArray{T, N}, G::NetworkGlow) where {T, N}
G.split_scales && ((Z_save, X) = split_states(X, G.Z_dims))
G.split_scales && ((Z_save, X) = split_states(X[:], G.Z_dims))
logdet = 0
for i=G.L:-1:1
if G.split_scales && i < G.L
X = tensor_cat(X, Z_save[i])
end
for j=G.K:-1:1
X = G.CL[i, j].inverse(X)
X = G.AN[i, j].inverse(X)
G.logdet ? (X, logdet1) = G.CL[i, j].inverse(X) : X = G.CL[i, j].inverse(X)
G.logdet ? (X, logdet2) = G.AN[i, j].inverse(X) : X = G.AN[i, j].inverse(X)
G.logdet && (logdet += (logdet1 + logdet2))
end

(G.split_scales) && (X = G.squeezer.inverse(X))
end
return X
G.logdet ? (return X, logdet) : (return X)
end

# Backward pass and compute gradients
function backward(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, G::NetworkGlow; set_grad::Bool=true) where {T, N}

# Split data and gradients
if G.split_scales
ΔZ_save, ΔX = split_states(ΔX, G.Z_dims)
Z_save, X = split_states(X, G.Z_dims)
ΔZ_save, ΔX = split_states(ΔX[:], G.Z_dims)
Z_save, X = split_states(X[:], G.Z_dims)
end

if ~set_grad
@@ -180,14 +183,44 @@ function backward(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, G::NetworkGl
set_grad ? (return ΔX, X) : (return ΔX, vcat(ΔθAN, ΔθCL), X, vcat(∇logdetAN, ∇logdetCL))
end

# Backward reverse pass and compute gradients
function backward_inv(ΔX::AbstractArray{T, N}, X::AbstractArray{T, N}, G::NetworkGlow) where {T, N}
G.split_scales && (X_save = array_of_array(X, G.L-1))
G.split_scales && (ΔX_save = array_of_array(ΔX, G.L-1))
orig_shape = size(X)

for i=1:G.L
G.split_scales && (ΔX = G.squeezer.forward(ΔX))
G.split_scales && (X = G.squeezer.forward(X))
for j=1:G.K
ΔX_, X_ = backward_inv(ΔX, X, G.AN[i, j])
ΔX, X = backward_inv(ΔX_, X_, G.CL[i, j])
end

if G.split_scales && i < G.L # don't split after last iteration
X, Z = tensor_split(X)
ΔX, ΔZx = tensor_split(ΔX)

X_save[i] = Z
ΔX_save[i] = ΔZx

G.Z_dims[i] = collect(size(X))
end
end

G.split_scales && (X = reshape(cat_states(X_save, X), orig_shape))
G.split_scales && (ΔX = reshape(cat_states(ΔX_save, ΔX), orig_shape))
return ΔX, X
end
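A hedged sketch of the network-level reverse backward pass, for the case where the Glow network acts as a generator, X = G.inverse(Z), inside a MAP-style objective; the loss and all sizes are assumptions for illustration:

```julia
using InvertibleNetworks

G = NetworkGlow(4, 32, 2, 3; logdet=false, split_scales=true)
Z = randn(Float32, 16, 16, 4, 2)

X  = G.inverse(Z)                   # generate from the latent Z
ΔX = copy(X)                        # stand-in for the data-misfit gradient w.r.t. X
ΔZ, Z_re = backward_inv(ΔX, X, G)   # gradient w.r.t. Z (and Z recomputed) for a MAP update
```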

## Jacobian-related utils
function jacobian(ΔX::AbstractArray{T, N}, Δθ::Vector{Parameter}, X, G::NetworkGlow) where {T, N}

if G.split_scales
Z_save = array_of_array(ΔX, G.L-1)
ΔZ_save = array_of_array(ΔX, G.L-1)
end
orig_shape = size(X)

logdet = 0
cls = 2*G.K*G.L
ΔθAN = Vector{Parameter}(undef, 0)
@@ -217,8 +250,8 @@ function jacobian(ΔX::AbstractArray{T, N}, Δθ::Vector{Parameter}, X, G::Netwo
end
end
if G.split_scales
X = cat_states(Z_save, X)
ΔX = cat_states(ΔZ_save, ΔX)
X = reshape(cat_states(Z_save, X), orig_shape)
ΔX = reshape(cat_states(ΔZ_save, ΔX), orig_shape)
end

return ΔX, X, logdet, vcat(ΔθAN, ΔθCL)
13 changes: 7 additions & 6 deletions test/test_layers/test_actnorm.jl
@@ -75,26 +75,27 @@ AN_rev = reverse(AN)
# Test with logdet enabled
AN = ActNorm(nc; logdet=true)
Y, lgdt = AN.forward(X)
#X_, lgdt = AN.inverse(X)

# Test initialization
@test isapprox(mean(Y), 0f0; atol=1f-6)
@test isapprox(var(Y), 1f0; atol=1f-3)

# Test invertibility
@test isapprox(norm(X - AN.inverse(AN.forward(X)[1]))/norm(X), 0f0, atol=1f-6)
@test isapprox(norm(X - AN.forward(AN.inverse(X))[1])/norm(X), 0f0, atol=1f-6)
@test isapprox(norm(X - AN.inverse(AN.forward(X)[1])[1])/norm(X), 0f0, atol=1f-6)
@test isapprox(norm(X - AN.forward(AN.inverse(X)[1])[1])/norm(X), 0f0, atol=1f-6)

# Reversed layer (all combinations)
AN_rev = reverse(AN)

@test isapprox(norm(X - AN_rev.inverse(AN_rev.forward(X)[1]))/norm(X), 0f0, atol=1f-6)
@test isapprox(norm(X - AN_rev.forward(AN_rev.inverse(X))[1])/norm(X), 0f0, atol=1f-6)
@test isapprox(norm(X - AN_rev.inverse(AN_rev.forward(X)[1])[1])/norm(X), 0f0, atol=1f-6)
@test isapprox(norm(X - AN_rev.forward(AN_rev.inverse(X)[1])[1])/norm(X), 0f0, atol=1f-6)

@test isapprox(norm(X - AN_rev.forward(AN.forward(X)[1])[1])/norm(X), 0f0, atol=1f-6)
@test isapprox(norm(X - AN_rev.inverse(AN.inverse(X)))/norm(X), 0f0, atol=1f-6)
@test isapprox(norm(X - AN_rev.inverse(AN.inverse(X)[1])[1])/norm(X), 0f0, atol=1f-6)

@test isapprox(norm(X - AN.forward(AN_rev.forward(X)[1])[1])/norm(X), 0f0, atol=1f-6)
@test isapprox(norm(X - AN.inverse(AN_rev.inverse(X)))/norm(X), 0f0, atol=1f-6)
@test isapprox(norm(X - AN.inverse(AN_rev.inverse(X)[1])[1])/norm(X), 0f0, atol=1f-6)
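Since `inverse` now also returns a log-determinant when `logdet=true`, a natural extra check, not part of this PR, is that the forward and inverse contributions cancel:

```julia
Y, lgdet_f = AN.forward(X)
_, lgdet_i = AN.inverse(Y)
@test isapprox(lgdet_f + lgdet_i, 0f0; atol=1f-5)
```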


###############################################################################