update examples + fix callbacks (#96)
CarloLucibello authored Jan 24, 2025
1 parent 9dca3f0 commit 121db93
Showing 20 changed files with 246 additions and 1,027 deletions.
17 changes: 14 additions & 3 deletions NEWS.md
@@ -1,10 +1,21 @@
# Tsunami Release Notes

## v0.3.0

**Breaking changes:**
- `fit!` returns `nothing` instead of a `FitState` object. The `FitState` object can be accessed via `trainer.fit_state` (see the sketch after this list).
- `on_before_pullback` has been removed. Use `on_train_batch_start` instead.
- `on_*_batch_start` hooks now receive the batch already moved to the device.
- Some hooks now take additional arguments.

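For instance (a minimal sketch; `model`, `trainer`, and `train_loader` are assumed to be defined as in the examples further down this commit):

```julia
Tsunami.fit!(model, trainer, train_loader)  # now returns nothing
fit_state = trainer.fit_state               # training status lives here
@show fit_state.step                        # e.g. the number of optimization steps so far
```
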
**Highlights:**

- Now Tsunami uses `MLDataDevices.DeviceIterator` to wrap dataloaders for more efficient device memory management.

- `training_step`, `validation_step`, and `test_step` can now return a named tuple for flexibility. One of the fields of the named tuple must be `loss`, which is used as the loss value (see the sketch below).

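A minimal sketch of a step returning a named tuple (it follows the `Tsunami.train_step` signature used in the MLP example further down this commit; `MyModel` and the `acc` metric are illustrative, not part of the release):

```julia
using Flux, Tsunami
using Statistics: mean

struct MyModel{T} <: FluxModule
    net::T
end
(m::MyModel)(x) = m.net(x)

function Tsunami.train_step(m::MyModel, trainer, batch)
    x, y = batch
    ŷ = m(x)
    loss = Flux.logitcrossentropy(ŷ, y)              # scalar loss
    acc = mean(Flux.onecold(ŷ) .== Flux.onecold(y))  # extra fields are allowed
    return (; loss, acc)                             # the `loss` field is required
end
```
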
## v0.2.0 - 2025-01-03

2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "Tsunami"
uuid = "36e41bbe-399b-4a86-8623-faa02b4c2ac8"
authors = "Carlo Lucibello"
version = "0.3.0-DEV"
version = "0.3.0"

[deps]
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
21 changes: 18 additions & 3 deletions docs/src/api/callbacks.md
@@ -6,17 +6,19 @@ CollapsedDocStrings = true

Callbacks are functions that are called at certain points in the training process. They are useful for logging, early stopping, and other tasks.

Callbacks are passed to the [`Trainer`](@ref) constructor:

```julia
callback1 = Checkpointer(...)
trainer = Trainer(..., callbacks = [callback1, ...])
```

Callbacks implement their functionality through the hooks described in the [Hooks](@ref) section of the documentation.

## Available Callbacks

A few callbacks are provided by Tsunami.

### Checkpoints

Callbacks for saving and loading the model and optimizer state.
@@ -28,6 +30,19 @@ Tsunami.load_checkpoint
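
A hypothetical usage sketch (the `Checkpointer` constructor arguments and the checkpoint file layout are assumptions, not documented API):

```julia
checkpointer = Checkpointer()  # assumed zero-argument constructor
trainer = Trainer(max_epochs = 3, callbacks = [checkpointer])
Tsunami.fit!(model, trainer, train_loader)

# Later: restore the saved model and optimizer state.
# The checkpoint file name below is illustrative.
ckpt_path = joinpath(trainer.fit_state.run_dir, "checkpoints", "ckpt_last.jld2")
ckpt = Tsunami.load_checkpoint(ckpt_path)
```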

## Writing Custom Callbacks

Users can write their own callbacks by defining custom types and implementing the hooks they need. For example:

```julia
struct MyCallback end

function Tsunami.on_train_epoch_end(cb::MyCallback, model, trainer)
fit_state = trainer.fit_state # contains info about the training status
# do something
end

trainer = Trainer(..., callbacks = [MyCallback()])
```

See the implementation of [`Checkpointer`](@ref) and the
[Hooks](@ref) section of the documentation for more information
on how to write custom callbacks. Also, the [examples](https://github.com/CarloLucibello/Tsunami.jl/tree/main/examples) folder contains some examples of custom callbacks.
Binary file removed docs/src/assets/readme_output.png
109 changes: 0 additions & 109 deletions examples/Bert_MNLI/bert_mnli.jl

This file was deleted.

15 changes: 15 additions & 0 deletions examples/MLP_MNIST/Project.toml
@@ -0,0 +1,15 @@
[deps]
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
ParameterSchedulers = "d7d3b36b-41b8-4d0d-a2bf-768c6151755e"
Tsunami = "36e41bbe-399b-4a86-8623-faa02b4c2ac8"

[compat]
Flux = "0.16"
MLDatasets = "0.7"
MLUtils = "0.4"
Optimisers = "0.4"
ParameterSchedulers = "0.4"
Tsunami = "0.3"
37 changes: 25 additions & 12 deletions examples/MLP_MNIST/mlp_mnist.jl
@@ -1,10 +1,19 @@
# # MNIST MLP Example
# This example demonstrates how to train a simple MLP model
# for image classification on the MNIST dataset using Tsunami.

# ## Setup
using Flux, Optimisers, Tsunami, MLDatasets
using MLUtils: MLUtils, DataLoader, flatten, mapobs, splitobs
import ParameterSchedulers

# Uncomment one of the following lines for GPU support

## using CUDA
## using AMDGPU
## using Metal

# ## Model Definition

struct MLP{T} <: FluxModule
net::T
@@ -19,9 +28,7 @@ function MLP()
return MLP(net)
end

(m::MLP)(x) = m.net(x)

function Tsunami.train_step(m::MLP, trainer, batch)
x, y = batch
@@ -55,6 +62,9 @@ function Tsunami.configure_optimisers(m::MLP, trainer)
return opt, lr_scheduler
end


# ## Data Preparation

train_data = mapobs(batch -> (batch[1], Flux.onehotbatch(batch[2], 0:9)), MNIST(:train))
train_data, val_data = splitobs(train_data, at = 0.9)
test_data = mapobs(batch -> (batch[1], Flux.onehotbatch(batch[2], 0:9)), MNIST(:test))
@@ -63,16 +73,17 @@ train_loader = DataLoader(train_data, batchsize=128, shuffle=true)
val_loader = DataLoader(val_data, batchsize=128, shuffle=true)
test_loader = DataLoader(test_data, batchsize=128)

# ## Training
# First, we create the model:

model = MLP()

# Now we do a fast dev run to make sure everything is working:

trainer = Trainer(fast_dev_run=true, accelerator=:auto)
Tsunami.fit!(model, trainer, train_loader, val_loader)

# We then train the model for real:

Tsunami.seed!(17)
trainer = Trainer(max_epochs = 3,
@@ -84,7 +95,8 @@ Tsunami.fit!(model, trainer, train_loader, val_loader)
@assert trainer.fit_state.step == 1266
run_dir = trainer.fit_state.run_dir

# We can also resume the training from the last checkpoint:

trainer = Trainer(max_epochs = 5,
default_root_dir = @__DIR__,
accelerator = :auto,
@@ -97,5 +109,6 @@ model = MLP()
Tsunami.fit!(model, trainer, train_loader, val_loader; ckpt_path)
@assert trainer.fit_state.step == 2110

# ## Testing

test_results = Tsunami.test(model, trainer, test_loader)
24 changes: 24 additions & 0 deletions examples/VAE_MNIST/Project.toml
@@ -0,0 +1,24 @@
[deps]
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
ImageShow = "4e3cecfd-b093-5904-9786-8bbb286a6a31"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Tsunami = "36e41bbe-399b-4a86-8623-faa02b4c2ac8"


[compat]
Flux = "0.16"
ImageShow = "0.3"
LinearAlgebra = "1"
MLDatasets = "0.7"
MLUtils = "0.4"
Optimisers = "0.4"
Random = "1"
Statistics = "1"
Tsunami = "0.3"