make sure all networks can handle 1d inputs #86

Open · wants to merge 2 commits into base: master
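For context: the point of this PR is that networks accept 1D inputs laid out as (length, channels, batch). A minimal sketch of that usage, with the constructor call and array shapes taken from the tests in this diff (the keyword values are illustrative, not prescribed by the PR):

using InvertibleNetworks

nx, n_in, n_cond, n_hidden, batchsize = 32, 3, 2, 4, 2
L, K = 2, 2
G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; split_scales=false, ndims=1)

X    = rand(Float32, nx, n_in,   batchsize)   # 1D data layout: (length, channels, batch)
Cond = rand(Float32, nx, n_cond, batchsize)

ZX, ZC, logdet = G.forward(X, Cond)           # latents plus log-determinant
X_ = G.inverse(ZX, ZC)                        # should reconstruct X up to round-off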
14 changes: 8 additions & 6 deletions examples/VI/amortized_posterior_inference.jl
@@ -14,12 +14,14 @@

# train using samples from joint distribution x,y ~ p(x,y) where x=[μ, σ] -> y = N(μ, σ)
# rows: μ, σ, y
num_params = 2; num_obs = 51; n_train = 10000
num_params = 2; num_obs = 51;
n_c_params = 1; n_c_obs = 1;
n_train = 10000
training_data = mapreduce(x -> generate_data(num_obs), hcat, 1:n_train);

############ train some data
X_train = reshape(training_data[1:num_params,:], (1,1,num_params,:));
Y_train = reshape(training_data[(num_params+1):end,:], (1,1,num_obs,:));
X_train = reshape(training_data[1:num_params,:], (num_params,n_c_params,:));
Y_train = reshape(training_data[(num_params+1):end,:], (num_obs,n_c_obs,:));

n_epochs = 2
batch_size = 200
@@ -31,7 +33,7 @@ using InvertibleNetworks, LinearAlgebra, Flux
L = 3 # RealNVP multiscale layers
K = 4 # Coupling layers per scale
n_hidden = 32 # Hidden channels in coupling layers' neural network
G = NetworkConditionalGlow(num_params, num_obs, n_hidden, L, K;);
G = NetworkConditionalGlow(n_c_params, n_c_obs, n_hidden, L, K; nx=nx, ndims=1);
opt = ADAM(4f-3)

# Training logs
@@ -40,8 +42,8 @@ loss_l2 = []; logdet_train = [];
for e=1:n_epochs # epoch loop
idx_e = reshape(1:n_train, batch_size, n_batches)
for b = 1:n_batches # batch loop
X = X_train[:, :, :, idx_e[:,b]];
Y = Y_train[:, :, :, idx_e[:,b]];
X = X_train[:, :, idx_e[:,b]];
Y = Y_train[:, :, idx_e[:,b]];

# Forward pass of normalizing flow
Zx, Zy, lgdet = G.forward(X, Y)
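The example now feeds the flow plain 1D arrays in the (length, channels, batch) layout instead of padding two singleton spatial dimensions. A small shape check of that layout with toy sizes (illustrative only, not part of the diff):

# Each training column holds [μ, σ, y_1, ..., y_51]; under the 1D convention the
# parameters and the observations become (length, channel, batch) arrays.
num_params, num_obs, n_train = 2, 51, 4
training_data = randn(Float32, num_params + num_obs, n_train)
X_train = reshape(training_data[1:num_params, :],       (num_params, 1, :))
Y_train = reshape(training_data[(num_params+1):end, :], (num_obs,    1, :))
@assert size(X_train) == (num_params, 1, n_train)
@assert size(Y_train) == (num_obs,    1, n_train)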
14 changes: 11 additions & 3 deletions src/conditional_layers/conditional_layer_glow.jl
@@ -75,10 +75,15 @@ end

# Constructor from input dimensions
function ConditionalLayerGlow(n_in::Int64, n_cond::Int64, n_hidden::Int64;freeze_conv=false, k1=3, k2=1, p1=1, p2=0, s1=1, s2=1, logdet=false, activation::ActivationFunction=SigmoidLayer(), rb_activation::ActivationFunction=RELUlayer(), ndims=2)

# 1x1 Convolution and residual block for invertible layers
C = Conv1x1(n_in; freeze=freeze_conv)
RB = ResidualBlock(Int(n_in/2)+n_cond, n_hidden; n_out=n_in, activation=rb_activation, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, fan=true, ndims=ndims)

split_num = Int(round(n_in/2))
in_rb = n_in-split_num
out_rb = 2*split_num

RB = ResidualBlock(in_rb+n_cond, n_hidden; n_out=out_rb, activation=rb_activation, k1=k1, k2=k2, p1=p1, p2=p2, s1=s1, s2=s2, fan=true, ndims=ndims)

return ConditionalLayerGlow(C, RB, logdet, activation)
end
@@ -143,7 +148,10 @@ function backward(ΔY::AbstractArray{T, N}, Y::AbstractArray{T, N}, C::AbstractA

# Backpropagate RB
ΔX2_ΔC = L.RB.backward(tensor_cat(L.activation.backward(ΔS, S), ΔT), (tensor_cat(X2, C)))
ΔX2, ΔC = tensor_split(ΔX2_ΔC; split_index=Int(size(ΔY)[N-1]/2))

n_in = size(ΔY)[N-1]
split_num = Int(round(n_in/2))
ΔX2, ΔC = tensor_split(ΔX2_ΔC; split_index=n_in-split_num)
ΔX2 += ΔY2

# Backpropagate 1x1 conv
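Both hunks above replace Int(n_in/2) with rounding-based split arithmetic so the layer also works when the channel count is odd, which 1D inputs make common (e.g. a single-channel conditioning variable). A sketch of the arithmetic using the constructor's own names (plain Julia, no library calls):

n_in, n_cond = 3, 2                 # odd channel count: Int(n_in/2) would throw an InexactError
split_num = Int(round(n_in/2))      # channels that receive the affine scale/shift (here 2)
in_rb     = n_in - split_num        # passthrough channels fed to the residual block (here 1)
out_rb    = 2*split_num             # residual block output: one scale and one shift per transformed channel
@assert (split_num, in_rb, out_rb) == (2, 1, 4)
# The backward pass splits the concatenated gradient with the same index,
# split_index = n_in - split_num, so the passthrough block lines up with ΔX2.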
312 changes: 106 additions & 206 deletions test/test_networks/test_conditional_glow_network.jl
@@ -5,158 +5,7 @@
using InvertibleNetworks, LinearAlgebra, Test, Random

# Random seed
Random.seed!(2);

# Define network
nx = 32
ny = 32
nz = 32
n_in = 2
n_cond = 2
n_hidden = 4
batchsize = 2
L = 2
K = 2
split_scales = true
N = (nx,ny)
########################################### Test with split_scales = true N = (nx,ny) #########################
# Invertibility

# Network and input
G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N))
X = rand(Float32, N..., n_in, batchsize)
Cond = rand(Float32, N..., n_cond, batchsize)

Y, Cond = G.forward(X,Cond)
X_ = G.inverse(Y,Cond) # saving the cond is important in split scales because of reshapes

@test isapprox(norm(X - X_)/norm(X), 0f0; atol=1f-5)

###################################################################################################
# Test gradients are set and cleared
G.backward(Y, Y, Cond)

P = get_params(G)
gsum = 0
for p in P
~isnothing(p.grad) && (global gsum += 1)
end
@test isequal(gsum, L*K*10+2)

clear_grad!(G)
gsum = 0
for p in P
~isnothing(p.grad) && (global gsum += 1)
end
@test isequal(gsum, 0)


###################################################################################################
# Gradient test

function loss(G, X, Cond)
Y, ZC, logdet = G.forward(X, Cond)
f = -log_likelihood(Y) - logdet
ΔY = -∇log_likelihood(Y)
ΔX, X_ = G.backward(ΔY, Y, ZC)
return f, ΔX, G.CL[1,1].RB.W1.grad, G.CL[1,1].C.v1.grad
end

# Gradient test w.r.t. input
G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N))
X = rand(Float32, N..., n_in, batchsize)
Cond = rand(Float32, N..., n_cond, batchsize)
X0 = rand(Float32, N..., n_in, batchsize)
Cond0 = rand(Float32, N..., n_cond, batchsize)

dX = X - X0

f0, ΔX = loss(G, X0, Cond0)[1:2]
h = 0.1f0
maxiter = 4
err1 = zeros(Float32, maxiter)
err2 = zeros(Float32, maxiter)

print("\nGradient test glow: input\n")
for j=1:maxiter
f = loss(G, X0 + h*dX, Cond0)[1]
err1[j] = abs(f - f0)
err2[j] = abs(f - f0 - h*dot(dX, ΔX))
print(err1[j], "; ", err2[j], "\n")
global h = h/2f0
end

@test isapprox(err1[end] / (err1[1]/2^(maxiter-1)), 1f0; atol=1f0)
@test isapprox(err2[end] / (err2[1]/4^(maxiter-1)), 1f0; atol=1f0)


# Gradient test w.r.t. parameters
X = rand(Float32, N..., n_in, batchsize)
G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N))
G0 = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N))
Gini = deepcopy(G0)

# Test one parameter from residual block and 1x1 conv
dW = G.CL[1,1].RB.W1.data - G0.CL[1,1].RB.W1.data
dv = G.CL[1,1].C.v1.data - G0.CL[1,1].C.v1.data

f0, ΔX, ΔW, Δv = loss(G0, X, Cond)
h = 0.1f0
maxiter = 4
err3 = zeros(Float32, maxiter)
err4 = zeros(Float32, maxiter)

print("\nGradient test glow: input\n")
for j=1:maxiter
G0.CL[1,1].RB.W1.data = Gini.CL[1,1].RB.W1.data + h*dW
G0.CL[1,1].C.v1.data = Gini.CL[1,1].C.v1.data + h*dv

f = loss(G0, X, Cond)[1]
err3[j] = abs(f - f0)
err4[j] = abs(f - f0 - h*dot(dW, ΔW) - h*dot(dv, Δv))
print(err3[j], "; ", err4[j], "\n")
global h = h/2f0
end

@test isapprox(err3[end] / (err3[1]/2^(maxiter-1)), 1f0; atol=1f0)
@test isapprox(err4[end] / (err4[1]/4^(maxiter-1)), 1f0; atol=1f0)


N = (nx,ny,nz)
########################################### Test with split_scales = true N = (nx,ny,nz) #########################
# Invertibility

# Network and input
G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N))
X = rand(Float32, N..., n_in, batchsize)
Cond = rand(Float32, N..., n_cond, batchsize)

Y, Cond = G.forward(X,Cond)
X_ = G.inverse(Y,Cond) # saving the cond is important in split scales because of reshapes

@test isapprox(norm(X - X_)/norm(X), 0f0; atol=1f-5)

###################################################################################################
# Test gradients are set and cleared
G.backward(Y, Y, Cond)

P = get_params(G)
gsum = 0
for p in P
~isnothing(p.grad) && (global gsum += 1)
end
@test isequal(gsum, L*K*10+2)

clear_grad!(G)
gsum = 0
for p in P
~isnothing(p.grad) && (global gsum += 1)
end
@test isequal(gsum, 0)


###################################################################################################
# Gradient test
Random.seed!(7);

function loss(G, X, Cond)
Y, ZC, logdet = G.forward(X, Cond)
@@ -166,62 +15,113 @@ function loss(G, X, Cond)
return f, ΔX, G.CL[1,1].RB.W1.grad, G.CL[1,1].C.v1.grad
end

# Gradient test w.r.t. input
G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N))
X = rand(Float32, N..., n_in, batchsize)
Cond = rand(Float32, N..., n_cond, batchsize)
X0 = rand(Float32, N..., n_in, batchsize)
Cond0 = rand(Float32, N..., n_cond, batchsize)

dX = X - X0

f0, ΔX = loss(G, X0, Cond0)[1:2]
h = 0.1f0
maxiter = 4
err1 = zeros(Float32, maxiter)
err2 = zeros(Float32, maxiter)

print("\nGradient test glow: input\n")
for j=1:maxiter
f = loss(G, X0 + h*dX, Cond0)[1]
err1[j] = abs(f - f0)
err2[j] = abs(f - f0 - h*dot(dX, ΔX))
print(err1[j], "; ", err2[j], "\n")
global h = h/2f0
function gradients_set(G, n_in,n_cond,N)
X = rand(Float32, N..., n_in, batchsize)
Cond = rand(Float32, N..., n_cond, batchsize)

XZ, CondZ = G.forward(X,Cond)
# Set gradients
G.backward(XZ, XZ, CondZ)

P = get_params(G)
gsum = 0
for p in P
~isnothing(p.grad) && (gsum += 1)
end
@test isequal(gsum, L*K*10+2)

clear_grad!(G)
gsum = 0
for p in P
~isnothing(p.grad) && (gsum += 1)
end
@test isequal(gsum, 0)
end

@test isapprox(err1[end] / (err1[1]/2^(maxiter-1)), 1f0; atol=1f0)
@test isapprox(err2[end] / (err2[1]/4^(maxiter-1)), 1f0; atol=1f0)


# Gradient test w.r.t. parameters
X = rand(Float32, N..., n_in, batchsize)
G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N))
G0 = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N))
Gini = deepcopy(G0)

# Test one parameter from residual block and 1x1 conv
dW = G.CL[1,1].RB.W1.data - G0.CL[1,1].RB.W1.data
dv = G.CL[1,1].C.v1.data - G0.CL[1,1].C.v1.data

f0, ΔX, ΔW, Δv = loss(G0, X, Cond)
h = 0.1f0
maxiter = 4
err3 = zeros(Float32, maxiter)
err4 = zeros(Float32, maxiter)

print("\nGradient test glow: input\n")
for j=1:maxiter
G0.CL[1,1].RB.W1.data = Gini.CL[1,1].RB.W1.data + h*dW
G0.CL[1,1].C.v1.data = Gini.CL[1,1].C.v1.data + h*dv
# Define network
nx = 32
ny = 32
nz = 32
n_in = 3
n_cond = 2
n_hidden = 4
batchsize = 2
L = 2
K = 2

f = loss(G0, X, Cond)[1]
err3[j] = abs(f - f0)
err4[j] = abs(f - f0 - h*dot(dW, ΔW) - h*dot(dv, Δv))
print(err3[j], "; ", err4[j], "\n")
global h = h/2f0
for split_scales in [false,true]
for N in [(nx),(nx,ny),(nx,ny,nz)]
println("Test with split_scales = $(split_scales) N = $(N) ")
########################################### Test with split_scales = true N = (nx,ny) #########################

# Network and inputs
G = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K; split_scales=split_scales, ndims=length(N))
X = rand(Float32, N..., n_in, batchsize)
Cond = rand(Float32, N..., n_cond, batchsize)

# Invertibility
XZ, CondZ = G.forward(X,Cond)
X_ = G.inverse(XZ, CondZ) # saving the cond output is important in split scales because of reshapes
@test isapprox(norm(X - X_)/norm(X), 0f0; atol=1f-5)

###################################################################################################
# Test gradients are set and cleared
gradients_set(G, n_in, n_cond,N)

###################################################################################################
# Gradient test w.r.t. input
X0 = rand(Float32, N..., n_in, batchsize)
Cond0 = rand(Float32, N..., n_cond, batchsize)

dX = X - X0

f0, ΔX = loss(G, X0, Cond0)[1:2]
h = 0.01f0
maxiter = 4
err1 = zeros(Float32, maxiter)
err2 = zeros(Float32, maxiter)

print("\nGradient test glow: input\n")
for j=1:maxiter
f = loss(G, X0 + h*dX, Cond0)[1]
err1[j] = abs(f - f0)
err2[j] = abs(f - f0 - h*dot(dX, ΔX))
print(err1[j], "; ", err2[j], "\n")
h = h/2f0
end

@test isapprox(err1[end] / (err1[1]/2^(maxiter-1)), 1f0; atol=1f0)
@test isapprox(err2[end] / (err2[1]/4^(maxiter-1)), 1f0; atol=1f0)


# Gradient test w.r.t. parameters
G0 = NetworkConditionalGlow(n_in, n_cond, n_hidden, L, K;split_scales=split_scales,ndims=length(N))
Gini = deepcopy(G0)

# Test one parameter from residual block and 1x1 conv
dW = G.CL[1,1].RB.W1.data - G0.CL[1,1].RB.W1.data
dv = G.CL[1,1].C.v1.data - G0.CL[1,1].C.v1.data

f0, ΔX, ΔW, Δv = loss(G0, X, Cond)
h = 0.01f0
maxiter = 4
err3 = zeros(Float32, maxiter)
err4 = zeros(Float32, maxiter)

print("\nGradient test glow: parameter\n")
for j=1:maxiter
G0.CL[1,1].RB.W1.data = Gini.CL[1,1].RB.W1.data + h*dW
G0.CL[1,1].C.v1.data = Gini.CL[1,1].C.v1.data + h*dv

f = loss(G0, X, Cond)[1]
err3[j] = abs(f - f0)
err4[j] = abs(f - f0 - h*dot(dW, ΔW) - h*dot(dv, Δv))
print(err3[j], "; ", err4[j], "\n")
h = h/2f0
end

@test isapprox(err3[end] / (err3[1]/2^(maxiter-1)), 1f0; atol=1f0)
@test isapprox(err4[end] / (err4[1]/4^(maxiter-1)), 1f0; atol=1f0)
end
end

@test isapprox(err3[end] / (err3[1]/2^(maxiter-1)), 1f0; atol=1f0)
@test isapprox(err4[end] / (err4[1]/4^(maxiter-1)), 1f0; atol=1f0)
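Each pass through the loops above runs the same first/second-order Taylor check: halving h should roughly halve the first-order error and quarter the gradient-corrected error, which is what the closing isapprox ratios assert (err1[1]/2^(maxiter-1) is the value err1[end] should reach under first-order decay, and likewise with 4 for err2). A toy illustration of the same check, independent of InvertibleNetworks:

using LinearAlgebra

f(x) = sum(sin.(x))                  # smooth toy objective
g(x) = cos.(x)                       # its exact gradient
x0, dx = randn(3), randn(3)

h = 0.1
for j = 1:4
    err1 = abs(f(x0 + h*dx) - f(x0))                      # decays like O(h)
    err2 = abs(f(x0 + h*dx) - f(x0) - h*dot(dx, g(x0)))   # decays like O(h^2)
    println(err1, "; ", err2)
    global h = h/2
end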