fix enzyme pullback (#84)
* fix enzyme pullback

* fix docs errors
CarloLucibello authored Jan 3, 2025
1 parent bee89f4 commit 1dcaa0c
Showing 10 changed files with 46 additions and 64 deletions.
14 changes: 14 additions & 0 deletions NEWS.md
@@ -0,0 +1,14 @@
# Tsunami Release Notes

## v0.2 - December 2024

Breaking changes:
- Device indexing now starts from 1, as in `MLDataDevices.gpu_device`.

Highlights:
- Updated to Flux v0.16.
- Models (i.e. subtypes of `FluxModule`) are no longer required to be mutable.
- `Tsunami.fit` is deprecated in favour of `Tsunami.fit!`.
- Added a Trainer option to use `Enzyme` for automatic differentiation.
- Improved test infrastructure.
- Improved documentation.
6 changes: 3 additions & 3 deletions docs/src/api/hooks.md
@@ -10,7 +10,7 @@ into a Callback at various points in the training, testing, and validation loops
At a high level, and omitting function inputs and outputs, a simplified version of the [`Tsunami.fit!`](@ref) method looks like this:

```julia
function fit()
function fit!()
configure_optimizers()

for epoch in epochs
@@ -25,9 +25,9 @@ function train_loop()
for batch in train_dataloader
on_train_batch_start()
batch = transfer_batch_to_device(batch)
loss, pb = pullback(m -> train_step(model, batch), model)
loss, pb = pullback(m -> train_step(m, batch), model)
on_before_backprop()
grad = pb(1)
grad = pb()
on_before_update()
update!(opt_state, model, grad)
on_train_batch_end()
2 changes: 1 addition & 1 deletion docs/src/guides.md
@@ -18,7 +18,7 @@ To select a specific GPU, use the `devices` keyword argument:
trainer = Trainer(devices = [1])
```

Devices are indexed starting from 1, as in the `MLDataDevices.get_device` method used by Flux.
Devices are indexed starting from 1, as in the `MLDataDevices.gpu_device` method used by Flux.
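A minimal sketch of the 1-based indexing in `MLDataDevices.gpu_device` that the `devices` option mirrors (assumes a GPU backend such as CUDA.jl is loaded; otherwise `gpu_device` falls back to the CPU):

```julia
using MLDataDevices

device = gpu_device(1)           # 1-based id: the first visible GPU, matching Trainer(devices = [1])
x = device(rand(Float32, 3, 3))  # move data to that device
```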

## Selecting an automatic differentiation engine

33 changes: 14 additions & 19 deletions ext/TsunamiEnzymeExt.jl
@@ -5,23 +5,18 @@ using Enzyme
using Functors: Functors
using Optimisers: Optimisers

# function Tsunami.pullback_train_step(model::Duplicated, trainer::Trainer, batch, batch_idx::Int)
# make_zero!(model.dval)
# ad = Enzyme.set_runtime_activity(Enzyme.ReverseSplitWithPrimal)
# args = (model, Const(trainer), Const(batch), Const(batch_idx))
# forward, reverse = autodiff_thunk(ad, Const{typeof(train_step)}, Active, map(typeof, args)...)
# tape, loss, _ = forward(Const(train_step), args...)
# pb = () -> reverse(Const(train_step), args..., one(loss), tape)
# return loss, pb
# end

function Tsunami.pullback_train_step(model::Duplicated, trainer::Trainer, batch, batch_idx::Int)
function Tsunami.pullback_train_step(model::Duplicated, trainer::Trainer, batch, batch_idx::Int)
make_zero!(model.dval)
ad = Enzyme.set_runtime_activity(ReverseWithPrimal)
ad = Enzyme.set_runtime_activity(ReverseSplitWithPrimal)
# ad = ReverseSplitWithPrimal
args = (model, Const(trainer), Const(batch), Const(batch_idx))
ret = Enzyme.autodiff(ad, Const(train_step), Active, args...)
pb = () -> model.dval
return ret[2], pb
forward, reverse = autodiff_thunk(ad, Const{typeof(train_step)}, Active, map(typeof, args)...)
tape, loss, _ = forward(Const(train_step), args...)
function pb()
reverse(Const(train_step), args..., one(loss), tape)
return model.dval
end
return loss, pb
end

function Tsunami.gradient_train_step(model::Duplicated, trainer::Trainer, batch, batch_idx::Int)
@@ -36,10 +31,10 @@ end
make_zero!(model) = Functors.fmapstructure(make_zero_inner!, model)

function make_zero_inner!(x::AbstractArray{<:Number})
Optimisers.isnumeric(x) || return
Optimisers.maywrite(x) || error("can't handle this")
fill!(x, zero(eltype(x)))
nothing
Optimisers.isnumeric(x) || return
Optimisers.maywrite(x) || error("can't handle this")
fill!(x, zero(eltype(x)))
nothing
end

make_zero_inner!(x) = nothing # any other Functors leaf type
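For orientation, here is a minimal sketch of the split reverse-mode pattern that the rewritten `pullback_train_step` relies on, applied to a toy function outside Tsunami (the function and arrays are illustrative, not part of the package):

```julia
using Enzyme

f(x) = sum(abs2, x)

x  = [1.0, 2.0, 3.0]
dx = zero(x)                    # shadow buffer: the gradient is accumulated here
xdup = Duplicated(x, dx)

fill!(dx, 0)                    # same role as make_zero! above: clear stale gradient state

ad = Enzyme.set_runtime_activity(ReverseSplitWithPrimal)
fwd, rev = autodiff_thunk(ad, Const{typeof(f)}, Active, typeof(xdup))

tape, loss, _ = fwd(Const(f), xdup)     # forward pass: records the tape and returns the primal
rev(Const(f), xdup, one(loss), tape)    # backward pass: writes the gradient (2x) into dx

loss, dx                                # (14.0, [2.0, 4.0, 6.0])
```

Splitting the forward and reverse passes is what lets `Tsunami.fit!` run the `on_before_backprop` hook between computing the loss and computing the gradient.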
5 changes: 0 additions & 5 deletions src/ProgressMeter/ProgressMeter.jl
@@ -409,11 +409,6 @@ function speedstring(sec_per_iter)
return " >100 d/it"
end

"""
rewind(p::AbstractProgress)
Rewinds the cursor to the beginning of the progress bar.
"""
function rewind(p::AbstractProgress)
print(p.output, "\r\u1b[A" ^ (p.offset + p.numprintedvalues))
end
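As a side note, a tiny illustrative snippet of the ANSI escape sequence that `rewind` repeats (not part of the package):

```julia
# "\r" moves the cursor to the start of the current line; "\u1b[A" (ESC [ A) moves it up one row.
# Repeating the pair (p.offset + p.numprintedvalues) times walks back over everything the meter printed.
print(stdout, "\r\u1b[A" ^ 3)   # e.g. return to the start of the line three rows above
```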
6 changes: 3 additions & 3 deletions src/fit.jl
@@ -146,10 +146,10 @@ function train_loop(model, trainer::Trainer, train_dataloader, val_dataloader)
batch = setup_batch(trainer.foil, batch)

loss, pb = pullback_train_step(model, trainer, batch, batch_idx)

hook(on_before_backprop, model, trainer, loss)

grad = pb()
## Alternative directly computing the gradient
# loss, grad = gradient_train_step(model, trainer, batch, batch_idx)

hook(on_before_update, model, trainer, grad)

@@ -205,7 +205,7 @@ function pullback_train_step(model::FluxModule, trainer::Trainer, batch, batch_i
end

function gradient_train_step(model::FluxModule, trainer::Trainer, batch, batch_idx::Int)
loss, z_grad = Zygote.gradient(model) do model
loss, z_grad = Zygote.withgradient(model) do model
loss = train_step(model, trainer, batch, batch_idx)
return loss
end
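As a quick illustration of the swap from `Zygote.gradient` to `Zygote.withgradient` above, a minimal sketch on a toy loss (the NamedTuple model is illustrative, not Tsunami's actual `train_step`):

```julia
using Zygote

model = (w = [1.0, 2.0],)
loss(m) = sum(abs2, m.w)

g = Zygote.gradient(loss, model)         # gradient only: ((w = [2.0, 4.0],),)
res = Zygote.withgradient(loss, model)   # primal value and gradient from a single forward/backward
res.val                                  # 5.0
res.grad                                 # ((w = [2.0, 4.0],),)
```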
13 changes: 0 additions & 13 deletions src/fluxmodule.jl
@@ -209,19 +209,6 @@ function test_step(model::FluxModule, trainer, batch)
return nothing
end

"""
copy!(dest::FluxModule, src::FluxModule)
Shallow copy of all fields of `src` to `dest`.
"""
function Base.copy!(dest::T1, src::T2) where {T1 <: FluxModule, T2 <: FluxModule}
@assert fieldnames(T1) == fieldnames(T2) "The two structs have different fields."
for f in fieldnames(T1)
setfield!(dest, f, getfield(src, f))
end
return dest
end

check_train_step(m::EnzymeCore.Duplicated, args...) = check_train_step(m.val, args...)

function check_train_step(m::FluxModule, trainer, batch)
4 changes: 2 additions & 2 deletions src/hooks.jl
@@ -19,8 +19,8 @@ on_before_update(cb, model, trainer, grad) = nothing
"""
on_before_backprop([callback,] model, trainer, loss)
Called after the model's forward, where also the pullback is computed,
but before the call to the pullback (the backward pass).
Called after the model's forward, where also the pullback is created,
but before the call to the pullback (the backward pass computing the gradient).
"""
on_before_backprop(model, trainer, loss) = nothing
on_before_backprop(cb, model, trainer, loss) = nothing
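To make the hook mechanism concrete, here is a hedged sketch of a user-defined callback overloading `on_before_update`, whose signature appears in the context above. The callback type, its body, and registering it through a `callbacks` keyword of `Trainer` are illustrative assumptions, not confirmed API:

```julia
using Tsunami

# Hypothetical callback type (Tsunami may require subtyping an abstract callback type).
struct GradInspector end

# Overload the hook from src/hooks.jl: called right before the optimiser update.
function Tsunami.on_before_update(::GradInspector, model, trainer, grad)
    @info "about to update the parameters"   # inspect or log `grad` here
    return nothing
end

# trainer = Trainer(max_epochs = 3, callbacks = [GradInspector()])  # assumed registration API
```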
12 changes: 6 additions & 6 deletions src/show.jl
@@ -31,13 +31,13 @@ function compact_show(io::IO, x)
show(IOContext(io, :limit => true, :compact => true), x)
end

"""
compact_typename(x::T) -> String
compact_typename(T) -> String
# """
# compact_typename(x::T) -> String
# compact_typename(T) -> String

Return a compact string representation of the type `T` of `x`.
Keep only the name and `T`'s parameters, discarding their own parameters.
"""
# Return a compact string representation of the type `T` of `x`.
# Keep only the name and `T`'s parameters, discarding their own parameters.
# """
compact_typename(x::T) where T = compact_typename(T)

function compact_typename(T::DataType)
15 changes: 3 additions & 12 deletions src/utils.jl
@@ -27,13 +27,8 @@ roundval(x::AbstractFloat) = roundval(Float64(x))
roundval(x::Int) = x
roundval(x::NamedTuple) = map(roundval, x)


"""
dir_with_version(dir::String)
Append a version number to `dir`.
"""
function dir_with_version(dir)
# Append a version number to `dir`.
function dir_with_version(dir::String)
i = 1
outdir = dir * "_$i"
while isdir(outdir)
@@ -71,11 +66,7 @@ function seed!(seed::Int)
end
end

"""
_length(x) -> Int
Return the length of `x` if defined, otherwise return -1.
"""
# Return the length of `x` if defined, otherwise return -1.
function _length(x)
try
return length(x)
