Skip to content

Commit

Permalink
added one more example on train test regression
Browse files Browse the repository at this point in the history
  • Loading branch information
florencejt committed Jan 11, 2024
1 parent 64b9e37 commit 0efbe2e
Show file tree
Hide file tree
Showing 6 changed files with 161 additions and 4 deletions.
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""
Training multiple models in a loop: k-fold regression
Comparing All Fusion Models
====================================================================
Welcome to the "Comparing Multiple K-Fold Trained Fusion Models" tutorial! In this tutorial, we'll explore how to train and compare multiple fusion models for a regression task using k-fold cross-validation with multimodal tabular data. This tutorial is designed to help you understand and implement key features, including:
Welcome to the "Comparing All Fusion Models" tutorial! In this tutorial, we'll explore how to train and compare multiple fusion models for a regression task using k-fold cross-validation with multimodal tabular data. This tutorial is designed to help you understand and implement key features, including:
- 📥 Importing fusion models based on modality types.
- 🚲 Setting training parameters for your models
Expand Down
2 changes: 1 addition & 1 deletion docs/examples/training_and_testing/README.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
.. _train_test_examples:

Training and Testing Examples
Training and Testing
==========================================

These are examples of how to train and validate fusion models with Fusilli.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Binary Classification: Training a K-Fold Model
K-Fold Cross-Validation: Binary Classification
======================================================
🚀 In this tutorial, we'll explore binary classification using K-fold cross validation.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
"""
Train/Test split: Regression
======================================================
🚀 In this tutorial, we'll explore regression using a train/test split.
Specifically, we're using the :class:`~.TabularCrossmodalMultiheadAttention` model.
Key Features:
- 📥 Importing a model based on its path.
- 🧪 Training and testing a model with train/test split.
- 📈 Plotting the loss curves of each fold.
- 📊 Visualising the results of a single train/test model using the :class:`~.RealsVsPreds` class.
"""

import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import os

from docs.examples import generate_sklearn_simulated_data
from fusilli.data import prepare_fusion_data
from fusilli.eval import RealsVsPreds
from fusilli.train import train_and_save_models

# sphinx_gallery_thumbnail_number = -1

# %%
# 1. Import the fusion model 🔍
# --------------------------------
# We're importing only one model for this example, the :class:`~.TabularCrossmodalMultiheadAttention` model.
# Instead of using the :func:`~fusilli.utils.model_chooser.import_chosen_fusion_models` function, we're importing the model directly like with any other library method.


from fusilli.fusionmodels.tabularfusion.crossmodal_att import (
TabularCrossmodalMultiheadAttention,
)

# %%
# 2. Set the training parameters 🎯
# -----------------------------------
# Now we're configuring our training parameters.
#
# For training and testing, the necessary parameters are:
# - Paths to the input data files.
# - Paths to the output directories.
# - ``prediction_task``: the type of prediction to be performed. This is either ``regression``, ``binary``, or ``classification``.
#
# Some optional parameters are:
#
# - ``kfold``: a boolean of whether to use k-fold cross-validation (True) or not (False). By default, this is set to False.
# - ``num_folds``: the number of folds to use. It can't be ``k=1``.
# - ``wandb_logging``: a boolean of whether to log the results using Weights and Biases (True) or not (False). Default is False.
# - ``test_size``: the proportion of the dataset to include in the test split. Default is 0.2.
# - ``batch_size``: the batch size to use for training. Default is 8.
# - ``multiclass_dimensions``: the number of classes to use for multiclass classification. Default is None unless ``prediction_task`` is ``multiclass``.
# - ``max_epochs``: the maximum number of epochs to train for. Default is 1000.

# Regression task
prediction_task = "regression"

# Set the batch size
batch_size = 32

# Setting output directories
output_paths = {
"losses": "loss_logs/one_model_regression_traintest",
"checkpoints": "checkpoints/one_model_regression_traintest",
"figures": "figures/one_model_regression_traintest",
}

# Create the output directories if they don't exist
for path in output_paths.values():
os.makedirs(path, exist_ok=True)

# Clearing the loss logs directory (only for the example notebooks)
for dir in os.listdir(output_paths["losses"]):
# remove files
for file in os.listdir(os.path.join(output_paths["losses"], dir)):
os.remove(os.path.join(output_paths["losses"], dir, file))
# remove dir
os.rmdir(os.path.join(output_paths["losses"], dir))

# %%
# 3. Generating simulated data 🔮
# --------------------------------
# Time to create some simulated data for our models to work their wonders on.
# This function also simulated image data which we aren't using here.

tabular1_path, tabular2_path = generate_sklearn_simulated_data(prediction_task,
num_samples=500,
num_tab1_features=10,
num_tab2_features=20)

data_paths = {
"tabular1": tabular1_path,
"tabular2": tabular2_path,
"image": "",
}

# %%
# 4. Training the fusion model 🏁
# --------------------------------------
# Now we're ready to train our model. We're using the :func:`~fusilli.train.train_and_save_models` function to train our model.
#
# First we need to create a data module using the :func:`~fusilli.data.prepare_fusion_data` function.
# This function takes the following parameters:
#
# - ``prediction_task``: the type of prediction to be performed.
# - ``fusion_model``: the fusion model to be trained.
# - ``data_paths``: the paths to the input data files.
# - ``output_paths``: the paths to the output directories.
#
# Then we pass the data module and the fusion model to the :func:`~fusilli.train.train_and_save_models` function.
# We're not using checkpointing for this example, so we set ``enable_checkpointing=False``. We're also setting ``show_loss_plot=True`` to plot the loss curve.


fusion_model = TabularCrossmodalMultiheadAttention

print("method_name:", fusion_model.method_name)
print("modality_type:", fusion_model.modality_type)
print("fusion_type:", fusion_model.fusion_type)

dm = prepare_fusion_data(prediction_task=prediction_task,
fusion_model=fusion_model,
data_paths=data_paths,
output_paths=output_paths,
batch_size=batch_size)

# train and test
single_model_list = train_and_save_models(
data_module=dm,
fusion_model=fusion_model,
enable_checkpointing=False, # False for the example notebooks
show_loss_plot=True,
metrics_list=["r2", "mae", "mse"]
)

# %%
# 6. Plotting the results 📊
# ----------------------------
# Now we're ready to plot the results of our model.
# We're using the :class:`~.RealsVsPreds` class to plot the confusion matrix.

reals_preds_fig = RealsVsPreds.from_final_val_data(
single_model_list
)
plt.show()
3 changes: 3 additions & 0 deletions docs/installation.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
.. _install_instructions:


How to Install
==============

Expand Down
6 changes: 6 additions & 0 deletions docs/quick_start.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@ This script provides a simple setup to train a model using ``fusilli`` on a sing

This code showcases the necessary steps to execute Fusilli on a single dataset.

**Before you run this, you need to:**

1. Install ``fusilli`` (see :ref:`install_instructions`).
2. Prepare your data and specify the paths to your data (see :ref:`data-loading`).
3. Specify output file paths (see :ref:`experiment-set-up`).


Usage Example
-------------
Expand Down

0 comments on commit 0efbe2e

Please sign in to comment.