Skip to content

Commit

Permalink
add a flow matching example (#91)
Browse files Browse the repository at this point in the history
* add a flow matching example

* finish up

* cleanup

* fx path
  • Loading branch information
CarloLucibello authored Jan 6, 2025
1 parent 1b7cd5c commit 82e144b
Show file tree
Hide file tree
Showing 9 changed files with 120 additions and 95 deletions.
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
[deps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
HuggingFaceDatasets = "d94b9a45-fdf5-4270-b024-5cbb9ef7117d"
ImageCore = "a09fc81d-aa75-5fe9-8630-4744c3626534"
ImageShow = "4e3cecfd-b093-5904-9786-8bbb286a6a31"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
ParameterSchedulers = "d7d3b36b-41b8-4d0d-a2bf-768c6151755e"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Tsunami = "36e41bbe-399b-4a86-8623-faa02b4c2ac8"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
114 changes: 114 additions & 0 deletions examples/FlowMatching_Chessboard/flow_matching_chessboard.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# # Flow Matching Example: Chessboard

# This example is ported the pytorch notebook at
# https://github.com/facebookresearch/flow_matching/blob/main/examples/2d_flow_matching.ipynb

# We train and evaluate a simple 2D FM model with a linear scheduler.

# ## Imports

using Statistics, Random
using Tsunami, Flux, Optimisers
using Plots, ConcreteStructs
using MLUtils
using Enzyme
using CUDA, cuDNN

# ## Data Generation

# We define a generator of 2D points on a chessboard pattern.

function inf_train_gen(batch_size::Int = 200)
n = 4
x1 = rand(Float32, batch_size) .* n
x2 = rand(Float32, batch_size) .+ rand(0:2:n-2, batch_size)
x2 = x2 .+ floor.(x1) .% 2
data = vcat(x1', x2')
data .-= n / 2
return data
end

# Let's visualize a data sample.

function plot_checkboard(points)
scatter(points[1,:], points[2,:], xlim=(-2,2), ylim=(-2,2),
markersize=3, legend=false, widen=false, framestyle=:box)
end

data = inf_train_gen(1000)
plot_checkboard(data)

# ## Model Definition
@concrete struct FlowModel <: FluxModule
net
hparams
end

function FlowModel(; input_dim::Int = 2, time_dim::Int = 1, hidden_dim::Int = 128, lr = 0.001)
net = Chain(
Dense(input_dim + time_dim, hidden_dim, elu),
Dense(hidden_dim, hidden_dim, elu),
Dense(hidden_dim, hidden_dim, elu),
Dense(hidden_dim, input_dim)
)
hparams = (; lr)
return FlowModel(net, hparams)
end

(m::FlowModel)(x::AbstractMatrix, t::Number) = m(x, fill_like(x, t, size(x, 2)))
(m::FlowModel)(x::AbstractMatrix, t::AbstractVector) = m(x, reshape(t, 1, :))
(m::FlowModel)(x::AbstractMatrix, t::AbstractMatrix) = m.net(vcat(x, t))

function Tsunami.configure_optimisers(m::FlowModel, trainer)
opt = Optimisers.setup(Optimisers.Adam(m.hparams.lr), m)
return opt
end

function Tsunami.train_step(m::FlowModel, trainer, batch, batch_idx)
x1 = batch
batch_size = size(x1, 2)
x0 = randn_like(x1)
t = rand_like(x1, (1, batch_size))
xt = @. (1 - t) * x0 + t * x1
v = x1 .- x0
= m(xt, t)
loss = Flux.mse(v̂, v)
Tsunami.log(trainer, "loss/train", loss)
return loss
end

function train(; lr = 1e-4, batch_size = 256, iterations = 50000, hidden_dim = 512)
train_loader = (inf_train_gen(batch_size) for _ in 1:iterations)
# train_loader = (make_moons(batch_size, 0.05)[1] for _ in 1:iterations)

model = FlowModel(; input_dim=2, hidden_dim, lr)
trainer = Trainer(max_epochs=1, log_every_n_steps=50,
accelerator=:auto, autodiff=:enzyme)
Tsunami.fit!(model, trainer, train_loader)
return model
end

# Let's train the model.

model = train(lr=1e-3, hidden_dim=512, batch_size=4096, iterations=20000)

# ## Sampling

function step(m::FlowModel, x_t::AbstractMatrix, t_start::Number, t_end::Number)
vt = m(x_t, t_start)
xhalf = @. x_t + vt * (t_end - t_start) / 2
vhalf = m(xhalf, (t_start + t_end) / 2)
return @. x_t + vhalf * (t_end - t_start)
end

function sample(m::FlowModel; n::Int , steps::Int)
x = randn(Float32, 2, n)
ts = Float32.(range(0, 1, length=steps+1))
for i in 1:steps
x = step(m, x, ts[i], ts[i+1])
end
return x
end

samples = sample(m, n=1000, steps=10)
plot_checkboard(samples)
File renamed without changes.
File renamed without changes.
89 changes: 0 additions & 89 deletions examples/bert_mnli_original.jl

This file was deleted.

File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,5 @@ if TEST_GPU
end

@testset "Examples" begin
include(joinpath(@__DIR__, "..", "examples/mlp_mnist.jl"))
include(joinpath(@__DIR__, "..", "examples/MLP_MNIST/mlp_mnist.jl"))
end

0 comments on commit 82e144b

Please sign in to comment.