Commit: changes for v0.3 (#94)
CarloLucibello authored Jan 14, 2025
1 parent 82e144b commit a84daf8

Showing 16 changed files with 260 additions and 196 deletions.
4 changes: 1 addition & 3 deletions Project.toml
@@ -1,7 +1,7 @@
name = "Tsunami"
uuid = "36e41bbe-399b-4a86-8623-faa02b4c2ac8"
authors = "Carlo Lucibello"
version = "0.2.0"
version = "0.3.0-DEV"

[deps]
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
@@ -14,7 +14,6 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
@@ -42,7 +41,6 @@ Enzyme = "0.13.27"
EnzymeCore = "0.8.8"
Flux = "0.16"
Functors = "0.5.2"
GPUArraysCore = "0.1"
JLD2 = "0.5.10"
MLDataDevices = "1.6"
MLUtils = "0.4"
1 change: 0 additions & 1 deletion docs/src/api/foil.md
@@ -9,5 +9,4 @@ The [`Foil`](@ref) is a minimalistic version of the [`Trainer`](@ref) that allow
```@docs
Foil
Tsunami.setup
Tsunami.setup_batch
```
26 changes: 12 additions & 14 deletions docs/src/api/hooks.md
@@ -4,8 +4,7 @@ CollapsedDocStrings = true

# Hooks

Hooks are a way to extend the functionality of Tsunami. They are a way to inject custom code into the FluxModule or
into a Callback at various points in the training, testing, and validation loops.
Hooks are a way to extend the functionality of Tsunami. They are a way to inject custom code into the FluxModule or into a Callback at various points in the training, testing, and validation loops.

At a high level, and omitting function inputs and outputs, a simplified version of the [`Tsunami.fit!`](@ref) method looks like this:

@@ -22,15 +21,13 @@ function train_loop()
on_train_epoch_start()
set_learning_rate(lr_scheduler, epoch)

for batch in train_dataloader
on_train_batch_start()
for (batch, batch_idx) in enumerate(train_dataloader)
batch = transfer_batch_to_device(batch)
loss, pb = pullback(m -> train_step(m, batch), model)
on_before_backprop()
grad = pb()
on_before_update()
on_train_batch_start(batch, batch_idx)
out, grad = out_and_gradient(train_step, model, trainer, batch, batch_idx)
on_before_update(out, grad)
update!(opt_state, model, grad)
on_train_batch_end()
on_train_batch_end(out, batch, batch_idx)
if should_check_val
val_loop()
end
@@ -40,20 +37,21 @@

function val_loop()
on_val_epoch_start()
for batch in val_dataloader
on_val_batch_start()
for (batch, batch_idx) in val_dataloader
batch = transfer_batch_to_device(batch)
val_step(batch)
on_val_batch_end()
on_val_batch_start(batch, batch_idx)
out = val_step(model, trainer, batch, batch_idx)
on_val_batch_end(out, batch, batch_idx)
end
on_val_epoch_end()
end
```

Each `on_something` hook takes as input the model and the trainer.
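
For example, a callback can implement one of these hooks; with this commit the batch-level hooks also receive the step output, the batch, and the batch index. A minimal sketch, where the `LossPrinter` callback and the `callbacks` keyword wiring are illustrative assumptions rather than part of this diff:

```julia
using Tsunami

struct LossPrinter end  # hypothetical callback type

# Runs after every training batch; `out` is whatever `train_step` returned
# (a scalar loss or a named tuple containing a `loss` field).
function Tsunami.on_train_batch_end(::LossPrinter, model, trainer, out, batch, batch_idx)
    loss = out isa NamedTuple ? out.loss : out
    println("batch $batch_idx: loss = $loss")
end

# Assumed wiring: trainer = Trainer(max_epochs = 3, callbacks = [LossPrinter()])
```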

## Hooks API

```@docs
Tsunami.on_before_backprop
Tsunami.on_before_update
Tsunami.on_train_batch_start
Tsunami.on_train_batch_end
138 changes: 77 additions & 61 deletions examples/pytorch-lighting/autoencoder.ipynb

Large diffs are not rendered by default.

39 changes: 24 additions & 15 deletions ext/TsunamiEnzymeExt.jl
@@ -5,26 +5,35 @@ using Enzyme
using Functors: Functors
using Optimisers: Optimisers

function Tsunami.pullback_train_step(model::Duplicated, trainer::Trainer, batch, batch_idx::Int)
make_zero!(model.dval)
ad = Enzyme.set_runtime_activity(ReverseSplitWithPrimal)
# ad = ReverseSplitWithPrimal
args = (model, Const(trainer), Const(batch), Const(batch_idx))
forward, reverse = autodiff_thunk(ad, Const{typeof(train_step)}, Active, map(typeof, args)...)
tape, loss, _ = forward(Const(train_step), args...)
function pb()
reverse(Const(train_step), args..., one(loss), tape)
return model.dval
end
return loss, pb
end
# function Tsunami.pullback_train_step(model::Duplicated, trainer::Trainer, batch, batch_idx::Int)
# make_zero!(model.dval)
# ad = Enzyme.set_runtime_activity(ReverseSplitWithPrimal)
# # ad = ReverseSplitWithPrimal
# args = (model, Const(trainer), Const(batch), Const(batch_idx))
# forward, reverse = autodiff_thunk(ad, Const{typeof(train_step)}, Active, map(typeof, args)...)
# tape, loss, _ = forward(Const(train_step), args...)
# function pb()
# reverse(Const(train_step), args..., one(loss), tape)
# return model.dval
# end
# return loss, pb
# end

function Tsunami.gradient_train_step(model::Duplicated, trainer::Trainer, batch, batch_idx::Int)
make_zero!(model.dval)
ad = Enzyme.set_runtime_activity(ReverseWithPrimal)
args = (model, Const(trainer), Const(batch), Const(batch_idx))
ret = Enzyme.autodiff(ad, Const(train_step), Active, args...)
return ret[2], model.dval

out = Ref{Any}() # TODO crashes if I set `local out` instead of Ref
function f(model, trainer, batch, batch_idx)
out[] = train_step(model, trainer, batch, batch_idx)
loss = Tsunami.process_out_step(out[])
return loss
end

ret = Enzyme.autodiff(ad, Const(f), Active, args...)
# return ret[2], model.dval
return out[], model.dval
end

# We can't use Enzyme.make_zero! to reset Duplicated, as it complains about e.g. LayerNorm having immutable differentiable fields
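
The rewritten `gradient_train_step` above differentiates only the scalar loss while smuggling the full `train_step` output (possibly a named tuple) out through a `Ref`. A self-contained sketch of that pattern, assuming current Enzyme.jl semantics; `step`, `f`, and the toy data are illustrative, not part of the package:

```julia
using Enzyme

# Stand-in for `train_step`: returns a named tuple, of which only `loss` is differentiated.
step(x) = (; loss = sum(abs2, x), n = length(x))

x = Duplicated([1.0, 2.0, 3.0], zeros(3))  # value plus shadow (gradient accumulator)

out = Ref{Any}()                 # captures the whole step output
function f(x)
    out[] = step(x)              # stash everything for hooks/logging
    return out[].loss            # hand Enzyme only the scalar loss
end

Enzyme.autodiff(Enzyme.set_runtime_activity(ReverseWithPrimal), Const(f), Active, x)
out[]    # (loss = 14.0, n = 3)
x.dval   # gradient of the loss: [2.0, 4.0, 6.0]
```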
3 changes: 1 addition & 2 deletions src/Tsunami.jl
@@ -40,7 +40,7 @@ include("loggers/tensorboard.jl")

include("foil.jl")
export Foil
@compat(public, (setup, setup_batch))
@compat(public, (setup,))

include("trainer.jl")
export Trainer
@@ -62,7 +62,6 @@ include("show.jl")

include("hooks.jl")
@compat(public, (on_before_update,
on_before_backprop,
on_train_epoch_start,
on_train_epoch_end,
on_val_epoch_start,
20 changes: 20 additions & 0 deletions src/deprecated.jl
@@ -14,3 +14,23 @@ function fit(model::FluxModule, trainer::Trainer, args...; kws...)
return newmodel, trainer.fit_state
end

function setup_batch(foil::Foil, batch)
@warn "setup_batch is deprecated. Setup the dataloader with the Foil instead."
return batch |> to_precision(foil) |> to_device(foil)
end
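
For user code, the migration implied by this deprecation looks roughly as follows. `Tsunami.setup_iterator` is the internal helper this commit calls in `fit.jl`; whether it is the intended public replacement is an assumption here, so double-check against the v0.3 docs:

```julia
# Before (deprecated): move each batch to the right device/precision by hand.
for batch in train_dataloader
    batch = Tsunami.setup_batch(foil, batch)
    # ... forward/backward on batch ...
end

# After (assumed): wrap the loader once with the Foil, then just iterate.
train_dataloader = Tsunami.setup_iterator(foil, train_dataloader)
for batch in train_dataloader
    # batches arrive already on device and in the configured precision
end
```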

##### v0.3 DEPRECATIONS #####
#TODO deprecate properly
# on_train_batch_end(model, trainer) = nothing
# on_train_batch_end(cb, model, trainer) = nothing
# on_val_batch_end(model, trainer) = nothing
# on_val_batch_end(cb, model, trainer) = nothing
# on_test_batch_end(model, trainer) = nothing
# on_test_batch_end(cb, model, trainer) = nothing
# on_before_backprop(model, trainer, loss) = nothing
# on_before_backprop(cb, model, trainer, loss) = nothing
# on_before_update(model, trainer, grad) = nothing
# on_before_update(cb, model, trainer, grad) = nothing



60 changes: 25 additions & 35 deletions src/fit.jl
@@ -65,6 +65,10 @@ function fit!(model::FluxModule, trainer::Trainer, train_dataloader, val_dataloa
model = EnzymeCore.Duplicated(model)
end

# setup loaders
train_dataloader = setup_iterator(trainer.foil, train_dataloader)
val_dataloader = setup_iterator(trainer.foil, val_dataloader)

trainer.optimisers = optimisers
trainer.lr_schedulers = lr_schedulers

@@ -73,7 +77,7 @@
if val_dataloader !== nothing
check_val_step(model, trainer, first(val_dataloader))
end
return fit_state
return nothing
end

val_loop(model, trainer, val_dataloader; progbar_keep=false, progbar_print_epoch=true)
@@ -85,7 +89,7 @@ function fit!(model::FluxModule, trainer::Trainer, train_dataloader, val_dataloa
end

Flux.loadmodel!(model_orig, Flux.state(model))
return fit_state
return nothing
end

val_loop(m::EnzymeCore.Duplicated, args...; kws...) = val_loop(m.val, args...; kws...)
@@ -102,17 +106,15 @@ function val_loop(model::FluxModule, trainer::Trainer, val_dataloader; progbar_o
valprogressbar = Progress(_length(val_dataloader); desc=progbar_desc,
showspeed=true, enabled=trainer.progress_bar, color=:green, offset=progbar_offset, keep=progbar_keep)
for (batch_idx, batch) in enumerate(val_dataloader)
fit_state.batchsize = MLUtils.numobs(batch)

fit_state.batchsize = MLUtils.numobs(batch)
hook(on_val_batch_start, model, trainer, batch, batch_idx)

batch = setup_batch(trainer.foil, batch)
val_step(model, trainer, batch, batch_idx)
out = val_step(model, trainer, batch, batch_idx)
ProgressMeter.next!(valprogressbar,
showvalues = values_for_val_progressbar(trainer.metalogger),
valuecolor = :green)

hook(on_val_batch_end, model, trainer)
hook(on_val_batch_end, model, trainer, out, batch, batch_idx)
end

fit_state.stage = :val_epoch_end
@@ -142,18 +144,12 @@ function train_loop(model, trainer::Trainer, train_dataloader, val_dataloader)
for (batch_idx, batch) in enumerate(train_dataloader)
fit_state.step += 1
fit_state.batchsize = MLUtils.numobs(batch)

hook(on_train_batch_start, model, trainer, batch, batch_idx)

batch = setup_batch(trainer.foil, batch)
out, grad = gradient_train_step(model, trainer, batch, batch_idx)

loss, pb = pullback_train_step(model, trainer, batch, batch_idx)
hook(on_before_backprop, model, trainer, loss)
grad = pb()
## Alternative directly computing the gradient
# loss, grad = gradient_train_step(model, trainer, batch, batch_idx)

hook(on_before_update, model, trainer, grad)
hook(on_before_update, model, trainer, out, grad)

update!(trainer.optimisers, model, grad)

Expand All @@ -168,7 +164,7 @@ function train_loop(model, trainer::Trainer, train_dataloader, val_dataloader)
keep = fit_state.should_stop || fit_state.epoch == trainer.max_epochs
)

hook(on_train_batch_end, model, trainer)
hook(on_train_batch_end, model, trainer, out, batch, batch_idx)

fit_state.should_stop && break
end
@@ -196,22 +192,14 @@ end
update!(optimisers, m::EnzymeCore.Duplicated, grad) = update!(optimisers, m.val, grad)
update!(optimisers, m::FluxModule, grad) = Optimisers.update!(optimisers, m, grad)

function pullback_train_step(model::FluxModule, trainer::Trainer, batch, batch_idx::Int)
loss, z_pb = Zygote.pullback(model) do model
loss = train_step(model, trainer, batch, batch_idx)
return loss
end
# zygote returns a Ref with immutable, so we need to unref it
pb = () -> unref(z_pb(one(loss))[1])
return loss, pb
end

function gradient_train_step(model::FluxModule, trainer::Trainer, batch, batch_idx::Int)
local out
loss, z_grad = Zygote.withgradient(model) do model
loss = train_step(model, trainer, batch, batch_idx)
out = train_step(model, trainer, batch, batch_idx)
loss = process_out_step(out)
return loss
end
return loss, unref(z_grad[1])
return out, unref(z_grad[1])
end

# TODO remove when Optimisers.jl is able to handle gradients with (nested) Refs
@@ -229,6 +217,9 @@ function process_out_configure_optimisers(out)
return opt, lr_scheduler
end

process_out_step(loss::Number) = loss
process_out_step(out::NamedTuple) = out.loss

function print_fit_initial_summary(model, trainer)
cuda_available = is_cuda_functional()
amdgpu_available = is_amdgpu_functional()
Expand Down Expand Up @@ -281,6 +272,7 @@ Dict{String, Float64} with 1 entry:
"""
function test(model::FluxModule, trainer::Trainer, dataloader)
model = setup(trainer.foil, model)
dataloader = setup_iterator(trainer.foil, dataloader)
return test_loop(model, trainer, dataloader; progbar_keep=true)
end

@@ -291,7 +283,6 @@ function test_loop(model, trainer, dataloader; progbar_offset = 0, progbar_keep

hook(on_test_epoch_start, model, trainer)


testprogressbar = Progress(_length(dataloader); desc="Testing: ",
showspeed=true, enabled=trainer.progress_bar,
color=:green, offset=progbar_offset, keep=progbar_keep)
@@ -300,14 +291,13 @@

hook(on_test_batch_start, model, trainer, batch, batch_idx)

batch = setup_batch(trainer.foil, batch)
test_step(model, trainer, batch, batch_idx)
out = test_step(model, trainer, batch, batch_idx)

ProgressMeter.next!(testprogressbar,
showvalues = values_for_val_progressbar(trainer.metalogger),
valuecolor = :green
)
valuecolor = :green)

hook(on_test_batch_end, model, trainer)
hook(on_test_batch_end, model, trainer, out, batch, batch_idx)
end

fit_state.stage = :test_epoch_end
18 changes: 14 additions & 4 deletions src/fluxmodule.jl
@@ -145,6 +145,13 @@ for epoch in 1:epochs
end
```
The output can be either a scalar or a named tuple:
- If a scalar is returned, it is assumed to be the loss.
- If a named tuple is returned, it has to contain the `loss` field.
The output can be accessed in hooks such as [`on_before_update`](@ref) or [`on_train_batch_end`](@ref).
# Examples
```julia
@@ -175,6 +182,9 @@ validation epoch.
A `Model <: FluxModule` should implement either
`val_step(model::Model, trainer, batch)` or `val_step(model::Model, trainer, batch, batch_idx)`.
Optionally, the method can return a scalar or a named tuple, to be used in hooks such as
[`on_val_batch_end`](@ref).
See also [`train_step`](@ref).
# Examples
@@ -212,16 +222,16 @@ end
check_train_step(m::EnzymeCore.Duplicated, args...) = check_train_step(m.val, args...)

function check_train_step(m::FluxModule, trainer, batch)
batch = setup_batch(trainer.foil, batch)
out = train_step(m, trainer, batch, 1)
losserrmsg = "The output of `train_step` has to be a scalar."
@assert out isa Number losserrmsg
@assert out isa Union{Number,NamedTuple} "The output of `train_step` has to be a scalar or named tuple."
if out isa NamedTuple
@assert haskey(out, :loss) "A named tuple output of `train_step` has to contain the `loss` field."
end
end

check_val_step(m::EnzymeCore.Duplicated, args...) = check_val_step(m.val, args...)

function check_val_step(m::FluxModule, trainer, batch)
batch = setup_batch(trainer.foil, batch)
val_step(m, trainer, batch, 1)
@assert true
end
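
To make the updated `train_step` contract concrete, here is a minimal end-to-end sketch of a module returning a named tuple. The model, data shapes, and optimiser choice are made up, and the signatures follow the patterns visible in this diff rather than the official docs:

```julia
using Tsunami, Flux, Optimisers

mutable struct Regressor <: FluxModule   # hypothetical model
    net::Dense
end
Regressor() = Regressor(Dense(4 => 1))
(m::Regressor)(x) = m.net(x)

function Tsunami.train_step(m::Regressor, trainer, batch, batch_idx)
    x, y = batch
    loss = Flux.mse(m(x), y)
    # A named tuple is allowed as long as it carries a `loss` field;
    # the whole tuple is forwarded to hooks such as on_before_update and on_train_batch_end.
    return (; loss, nobs = size(x, 2))
end

Tsunami.configure_optimisers(m::Regressor, trainer) = Optimisers.Adam(1e-3)

# Assumed usage:
# trainer = Trainer(max_epochs = 5)
# train_loader = [(rand(Float32, 4, 32), rand(Float32, 1, 32)) for _ in 1:10]
# Tsunami.fit!(Regressor(), trainer, train_loader)
```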
