diff --git a/README.md b/README.md
index 2aab7d86..33908c08 100644
--- a/README.md
+++ b/README.md
@@ -5,130 +5,597 @@
![MacOS](https://img.shields.io/badge/mac%20os-000000?style=for-the-badge&logo=apple&logoColor=white)
[![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](https://opensource.org/licenses/Apache-2.0)
-![PyBadge](https://github.com/luxonis/luxonis-train/blob/main/media/pybadge.svg)
+![PyBadge](https://img.shields.io/pypi/pyversions/luxonis-train?logo=data:image/svg+xml%3Bbase64,PHN2ZyB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciIHZpZXdCb3g9IjAgMCAxMDAgMTAwIj4KICA8ZGVmcz4KICAgIDxsaW5lYXJHcmFkaWVudCBpZD0icHlZZWxsb3ciIGdyYWRpZW50VHJhbnNmb3JtPSJyb3RhdGUoNDUpIj4KICAgICAgPHN0b3Agc3RvcC1jb2xvcj0iI2ZlNSIgb2Zmc2V0PSIwLjYiLz4KICAgICAgPHN0b3Agc3RvcC1jb2xvcj0iI2RhMSIgb2Zmc2V0PSIxIi8+CiAgICA8L2xpbmVhckdyYWRpZW50PgogICAgPGxpbmVhckdyYWRpZW50IGlkPSJweUJsdWUiIGdyYWRpZW50VHJhbnNmb3JtPSJyb3RhdGUoNDUpIj4KICAgICAgPHN0b3Agc3RvcC1jb2xvcj0iIzY5ZiIgb2Zmc2V0PSIwLjQiLz4KICAgICAgPHN0b3Agc3RvcC1jb2xvcj0iIzQ2OCIgb2Zmc2V0PSIxIi8+CiAgICA8L2xpbmVhckdyYWRpZW50PgogIDwvZGVmcz4KCiAgPHBhdGggZD0iTTI3LDE2YzAtNyw5LTEzLDI0LTEzYzE1LDAsMjMsNiwyMywxM2wwLDIyYzAsNy01LDEyLTExLDEybC0yNCwwYy04LDAtMTQsNi0xNCwxNWwwLDEwbC05LDBjLTgsMC0xMy05LTEzLTI0YzAtMTQsNS0yMywxMy0yM2wzNSwwbDAtM2wtMjQsMGwwLTlsMCwweiBNODgsNTB2MSIgZmlsbD0idXJsKCNweUJsdWUpIi8+CiAgPHBhdGggZD0iTTc0LDg3YzAsNy04LDEzLTIzLDEzYy0xNSwwLTI0LTYtMjQtMTNsMC0yMmMwLTcsNi0xMiwxMi0xMmwyNCwwYzgsMCwxNC03LDE0LTE1bDAtMTBsOSwwYzcsMCwxMyw5LDEzLDIzYzAsMTUtNiwyNC0xMywyNGwtMzUsMGwwLDNsMjMsMGwwLDlsMCwweiBNMTQwLDUwdjEiIGZpbGw9InVybCgjcHlZZWxsb3cpIi8+CgogIDxjaXJjbGUgcj0iNCIgY3g9IjY0IiBjeT0iODgiIGZpbGw9IiNGRkYiLz4KICA8Y2lyY2xlIHI9IjQiIGN4PSIzNyIgY3k9IjE1IiBmaWxsPSIjRkZGIi8+Cjwvc3ZnPgo=)
[![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff)
![CI](https://github.com/luxonis/luxonis-train/actions/workflows/ci.yaml/badge.svg)
![Docs](https://github.com/luxonis/luxonis-train/actions/workflows/docs.yaml/badge.svg)
[![codecov](https://codecov.io/gh/luxonis/luxonis-train/graph/badge.svg?token=647MTHBYD5)](https://codecov.io/gh/luxonis/luxonis-train)
-Luxonis training framework (`luxonis-train`) is intended for training deep learning models that can run fast on OAK products.
+
-**The project is in a beta state and might be unstable or contain bugs - please report any feedback.**
+## Overview
-## Table Of Contents
+`LuxonisTrain` is a user-friendly tool designed to streamline the training of deep learning models, especially for edge devices. Built on top of `PyTorch Lightning`, it simplifies the process of training, testing, and exporting models with minimal coding required.
-- [Installation](#installation)
-- [Training](#training)
-- [Customizations](#customizations)
-- [Tuning](#tuning)
-- [Exporting](#exporting)
-- [Credentials](#credentials)
-- [Contributing](#contributing)
+### Key Features
-## Installation
+- **No Coding Required**: Define your training pipeline entirely through a single `YAML` configuration file.
+- **Predefined Configurations**: Utilize ready-made configs for common computer vision tasks to start quickly.
+- **Customizable**: Extend functionality with custom components using an intuitive Python API.
+- **Edge Optimized**: Focus on models optimized for deployment on edge devices with limited compute resources.
-`luxonis-train` is hosted on PyPi and can be installed with `pip` as:
+> \[!WARNING\]
+> **The project is in a beta state and might be unstable or contain bugs - please report any feedback.**
+
+
+
+## Quick Start
+
+Get started with `LuxonisTrain` in just a few steps:
+
+1. **Install `LuxonisTrain`**
+
+ ```bash
+ pip install luxonis-train
+ ```
+
+ This will create the `luxonis_train` executable in your `PATH`.
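+
+   You can verify the installation and list the available commands with:
+
+   ```bash
+   luxonis_train --help
+   ```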
+
+1. **Use the provided `configs/detection_light_model.yaml` configuration file**
+
+ You can download the file by executing the following command:
+
+ ```bash
+ wget https://raw.githubusercontent.com/luxonis/luxonis-train/main/configs/detection_light_model.yaml
+ ```
+
+1. **Find a suitable dataset for your task**
+
+ We will use a sample COCO dataset from `RoboFlow` in this example.
+
+1. **Start training**
+
+ ```bash
+ luxonis_train train \
+ --config detection_light_model.yaml \
+ loader.params.dataset_dir "roboflow://team-roboflow/coco-128/2/coco"
+ ```
+
+1. **Monitor progress with `TensorBoard`**
+
+ ```bash
+ tensorboard --logdir output/tensorboard_logs
+ ```
+
+   Open the provided URL in your browser to visualize the training progress.
+
+## Table Of Contents
+
+- [Overview](#overview)
+  - [Key Features](#key-features)
+- [Quick Start](#quick-start)
+- [Installation](#installation)
+- [Usage](#usage)
+  - [CLI](#cli)
+- [Configuration](#configuration)
+- [Data Preparation](#data-preparation)
+  - [Data Directory](#data-directory)
+  - [`LuxonisDataset`](#luxonisdataset)
+- [Training](#training)
+- [Testing](#testing)
+- [Inference](#inference)
+- [Exporting](#exporting)
+- [NN Archive](#nn-archive)
+- [Tuning](#tuning)
+- [Customizations](#customizations)
+- [Tutorials and Examples](#tutorials-and-examples)
+- [Credentials](#credentials)
+- [Contributing](#contributing)
+
+
+
+## Installation
+
+`LuxonisTrain` requires **Python 3.10** or higher. We recommend using a virtual environment to manage dependencies.
+
+**Install via `pip`**:
```bash
pip install luxonis-train
```
-This command will also create a `luxonis_train` executable in your `PATH`.
-See `luxonis_train --help` for more information.
+This will also install the `luxonis_train` CLI. For more information on how to use it, see [CLI Usage](#cli).
+
+
+
+## Usage
-## Usage
+You can use `LuxonisTrain` either from the **command line** or via the **Python API**.
+We will demonstrate both ways in the following sections.
-The entire configuration is specified in a `yaml` file. This includes the model
-structure, used losses, metrics, optimizers etc. For specific instructions and example
-configuration files, see [Configuration](https://github.com/luxonis/luxonis-train/blob/main/configs/README.md).
+
-### Data Preparation
+### CLI
-This library requires data to be in the Luxonis Dataset Format.
+The CLI is the most straightforward way to use `LuxonisTrain`. It provides several commands for training, testing, tuning, exporting, and more.
-For instructions on how to create a dataset in the LDF, follow the
-[examples](https://github.com/luxonis/luxonis-ml/tree/main/examples) in
-the [luxonis-ml](https://github.com/luxonis/luxonis-ml) repository.
+**Available commands:**
-To inspect dataset images by split (train, val, test), use the command:
+- `train` - Start the training process
+- `test` - Test the model on a specific dataset view
+- `infer` - Run inference on a dataset, image directory, or a video file
+- `export` - Export the model to either `ONNX` or `BLOB` format that can be run on edge devices
+- `archive` - Create an `NN Archive` file that can be used with our `DepthAI` API (coming soon)
+- `tune` - Tune the hyperparameters of the model for better performance
+- `inspect` - Inspect the dataset you are using and visualize the annotations
+
+**To get help on any command:**
```bash
-luxonis_train data inspect --config --view
+luxonis_train --help
```
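+
+**To get help on a specific command:**
+
+```bash
+luxonis_train train --help
+```
+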
-## Training
+Specific usage examples can be found in the respective sections below.
-Once you've created your `config.yaml` file you can train the model using this command:
+
-```bash
-luxonis_train train --config config.yaml
+## Configuration
+
+`LuxonisTrain` uses `YAML` configuration files to define the training pipeline. Here's a breakdown of the key sections:
+
+```yaml
+model:
+ name: model_name
+
+ # Use a predefined detection model instead of defining
+ # the model architecture manually
+ predefined_model:
+ name: DetectionModel
+ params:
+ variant: light
+
+# Download and parse the COCO dataset from RoboFlow.
+# Save it internally as `coco_test` dataset for future reference.
+loader:
+ params:
+ dataset_name: coco_test
+ dataset_dir: "roboflow://team-roboflow/coco-128/2/coco"
+
+trainer:
+ batch_size: 8
+ epochs: 200
+ n_workers: 8
+ validation_interval: 10
+
+ preprocessing:
+ train_image_size: [384, 384]
+
+    # Uses the ImageNet normalization by default
+ normalize:
+ active: true
+
+ # Augmentations are powered by Albumentations
+ augmentations:
+ - name: Defocus
+ - name: Sharpen
+ - name: Flip
+
+ callbacks:
+ - name: ExportOnTrainEnd
+ - name: ArchiveOnTrainEnd
+ - name: TestOnTrainEnd
+
+ optimizer:
+ name: SGD
+ params:
+ lr: 0.02
+
+ scheduler:
+ name: ConstantLR
```
-If you wish to manually override some config parameters you can do this by providing the key-value pairs. Example of this is:
+For an extensive list of all the available options, see [Configuration](https://github.com/luxonis/luxonis-train/blob/main/configs/README.md).
+
+We provide a set of predefined configuration files for the most common computer vision tasks.
+You can find them in the `configs` directory.
+
+
+
+## Data Preparation
+
+`LuxonisTrain` supports several ways of loading data:
+
+- using a data directory in one of the supported formats
+- using an already existing dataset in our custom `LuxonisDataset` format
+- using a custom loader
+ - to learn how to implement and use custom loaders, see [Customizations](#customizations)
+
+
+
+### Data Directory
+
+The easiest way to load data is to use a directory with the dataset in one of the supported formats.
+
+**Supported formats:**
+
+- `COCO` - We support COCO JSON format in two variants:
+ - [`RoboFlow`](https://roboflow.com/formats/coco-json)
+ - [`FiftyOne`](https://docs.voxel51.com/user_guide/export_datasets.html#cocodetectiondataset-export)
+- [`Pascal VOC XML`](https://roboflow.com/formats/pascal-voc-xml)
+- [`YOLO Darknet TXT`](https://roboflow.com/formats/yolo-darknet-txt)
+- [`YOLOv4 PyTorch TXT`](https://roboflow.com/formats/yolov4-pytorch-txt)
+- [`MT YOLOv6`](https://roboflow.com/formats/mt-yolov6)
+- [`CreateML JSON`](https://roboflow.com/formats/createml-json)
+- [`TensorFlow Object Detection CSV`](https://roboflow.com/formats/tensorflow-object-detection-csv)
+- `Classification Directory` - A directory with subdirectories for each class
+ ```plaintext
+ dataset_dir/
+  ├── train/
+  │   ├── class1/
+  │   │   ├── img1.jpg
+  │   │   ├── img2.jpg
+  │   │   └── ...
+  │   └── class2/
+  │       └── ...
+  ├── valid/
+  └── test/
+ ```
+- `Segmentation Mask Directory` - A directory with images and corresponding masks.
+ ```plaintext
+ dataset_dir/
+  ├── train/
+  │   ├── img1.jpg
+  │   ├── img1_mask.png
+  │   ├── ...
+  │   └── _classes.csv
+  ├── valid/
+  └── test/
+ ```
+ The masks are stored as grayscale `PNG` images where each pixel value corresponds to a class.
+ The mapping from pixel values to classes is defined in the `_classes.csv` file.
+ ```csv
+ Pixel Value, Class
+ 0, background
+ 1, class1
+ 2, class2
+ 3, class3
+ ```
+
+#### Preparing your Data
+
+1. Organize your dataset into one of the supported formats.
+1. Place your dataset in a directory accessible by the training script.
+1. Update the `dataset_dir` parameter in the configuration file to point to the dataset directory.
+
+**The `dataset_dir` can be one of the following:**
+
+- Local path to the dataset directory
+- URL to a remote dataset
+ - The dataset will be downloaded to a `"data"` directory in the current working directory
+ - **Supported URL protocols:**
+    - `s3://bucket/path/to/directory` for **AWS S3**
+    - `gs://bucket/path/to/directory` for **Google Cloud Storage**
+ - `roboflow://workspace/project/version/format` for **RoboFlow**
+ - `workspace` - name of the workspace the dataset belongs to
+ - `project` - name of the project the dataset belongs to
+ - `version` - version of the dataset
+ - `format` - one of `coco`, `darknet`, `voc`, `yolov4pytorch`, `mt-yolov6`, `createml`, `tensorflow`, `folder`, `png-mask-semantic`
+ - **example:** `roboflow://team-roboflow/coco-128/2/coco`
+
+**Example:**
+
+```yaml
+loader:
+ params:
+ dataset_name: "coco_test"
+ dataset_dir: "roboflow://team-roboflow/coco-128/2/coco"
+```
+
+
+
+### `LuxonisDataset`
+
+`LuxonisDataset` is our custom dataset format designed for easy and efficient dataset management.
+To learn more about how to create a dataset in this format from scratch, see the [Luxonis ML](https://github.com/luxonis/luxonis-ml) repository.
+
+To use the `LuxonisDataset` as a source of the data, specify the following in the config file:
+
+```yaml
+loader:
+ params:
+ # name of the dataset
+ dataset_name: "dataset_name"
+
+ # one of local (default), s3, gcs
+ bucket_storage: "local"
+```
+
+> \[!TIP\]
+> To inspect the loader output, use the `luxonis_train inspect` command:
+>
+> ```bash
+> luxonis_train inspect --config configs/detection_light_model.yaml
+> ```
+>
+> **The `inspect` command is currently only available in the CLI**
+
+
+
+## Training
+
+Once your configuration file and dataset are ready, start the training process.
+
+**CLI:**
```bash
-luxonis_train train --config config.yaml trainer.batch_size 8 trainer.epochs 10
+luxonis_train train --config configs/detection_light_model.yaml
```
-where key and value are space separated and sub-keys are dot (`.`) separated. If the configuration field is a list, then key/sub-key should be a number (e.g. `trainer.preprocessing.augmentations.0.name RotateCustom`).
+> \[!TIP\]
+> To change a configuration parameter from the command line, use the following syntax:
+>
+> ```bash
+> luxonis_train train \
+> --config configs/detection_light_model.yaml \
+> loader.params.dataset_dir "roboflow://team-roboflow/coco-128/2/coco"
+> ```
-## Evaluating
+**Python API:**
-To evaluate the model on a specific dataset split (train, test, or val), use the following command:
+```python
+from luxonis_train import LuxonisModel
+
+model = LuxonisModel(
+ "configs/detection_light_model.yaml",
+ {"loader.params.dataset_dir": "roboflow://team-roboflow/coco-128/2/coco"}
+)
+model.train()
+```
+
+**Expected Output:**
+
+```log
+INFO Using predefined model: `DetectionModel`
+INFO Main metric: `MeanAveragePrecision`
+INFO GPU available: True (cuda), used: True
+INFO TPU available: False, using: 0 TPU cores
+INFO HPU available: False, using: 0 HPUs
+...
+INFO Training finished
+INFO Checkpoints saved in: output/1-coral-wren
+```
+
+**Monitoring with `TensorBoard`:**
+
+If not explicitly disabled, the training process will be monitored by `TensorBoard`. To start the `TensorBoard` server, run:
```bash
-luxonis_train eval --config --view
+tensorboard --logdir output/tensorboard_logs
```
-## Tuning
+Open the provided URL to visualize training metrics.
+
+
-To improve training performance you can use `Tuner` for hyperparameter optimization.
-To use tuning, you have to specify [tuner](https://github.com/luxonis/luxonis-train/blob/main/configs/README.md#tuner) section in the config file.
+## Testing
-To start the tuning, run
+Evaluate your trained model on a specific dataset view (`train`, `val`, or `test`).
+
+**CLI:**
```bash
-luxonis_train tune --config config.yaml
+luxonis_train test --config configs/detection_light_model.yaml \
+ --view val \
+ --weights path/to/checkpoint.ckpt
```
-You can see an example tuning configuration [here](https://github.com/luxonis/luxonis-train/blob/main/configs/example_tuning.yaml).
+**Python API:**
+
+```python
+from luxonis_train import LuxonisModel
-## Exporting
+model = LuxonisModel("configs/detection_light_model.yaml")
+model.test(weights="path/to/checkpoint.ckpt")
+```
-We support export to `ONNX`, and `DepthAI .blob format` which is used for OAK cameras. By default, we export to `ONNX` format.
+The testing process can be started automatically at the end of the training by using the `TestOnTrainEnd` callback.
+To learn more about callbacks, see [Callbacks](https://github.com/luxonis/luxonis-train/blob/main/luxonis_train/callbacks/README.md).
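+
+For example, to evaluate on the test set automatically once training finishes, add the callback to your configuration file:
+
+```yaml
+trainer:
+  callbacks:
+    - name: TestOnTrainEnd
+```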
-To use the exporter, you have to specify the [exporter](https://github.com/luxonis/luxonis-train/blob/main/configs/README.md#exporter) section in the config file.
+
-Once you have the config file ready you can export the model using
+## Inference
+
+Run inference on images, datasets, or videos.
+
+**CLI:**
+
+- **Inference on a Dataset View:**
+
+```bash
+luxonis_train infer --config configs/detection_light_model.yaml \
+ --view val \
+ --weights path/to/checkpoint.ckpt
+```
+
+- **Inference on a Video File:**
+
+```bash
+luxonis_train infer --config configs/detection_light_model.yaml \
+ --weights path/to/checkpoint.ckpt \
+ --source-path path/to/video.mp4
+```
+
+- **Inference on an Image Directory:**
+
```bash
-luxonis_train export --config config.yaml
+luxonis_train infer --config configs/detection_light_model.yaml \
+ --weights path/to/checkpoint.ckpt \
+ --source-path path/to/images \
+ --save-dir path/to/save_directory
+```
+
+**Python API:**
+
+```python
+from luxonis_train import LuxonisModel
+
+model = LuxonisModel("configs/detection_light_model.yaml")
+
+# infer on a dataset view
+model.infer(weights="path/to/checkpoint.ckpt", view="val")
+
+# infer on a video file
+model.infer(weights="path/to/checkpoint.ckpt", source_path="path/to/video.mp4")
+
+# infer on an image directory and save the results
+model.infer(
+ weights="path/to/checkpoint.ckpt",
+ source_path="path/to/images",
+ save_dir="path/to/save_directory",
+)
```
+
+
+## Exporting
+
+Export your trained models to formats suitable for deployment on edge devices.
+
+Supported formats:
+
+- **ONNX**: Open Neural Network Exchange format.
+- **BLOB**: Format compatible with OAK-D cameras.
+
+To configure the exporter, you can specify the [exporter](https://github.com/luxonis/luxonis-train/blob/main/configs/README.md#exporter) section in the config file.
+
You can see an example export configuration [here](https://github.com/luxonis/luxonis-train/blob/main/configs/example_export.yaml).
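+
+As a minimal sketch, the exporter section could look like this (the `onnx` sub-section name is an assumption based on the options reference linked above):
+
+```yaml
+exporter:
+  # models are exported to ONNX by default;
+  # the data type is only used for BLOB conversion
+  data_type: FP16
+  onnx:
+    opset_version: 12
+```
+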
-## Customizations
+**CLI:**
+
+```bash
+luxonis_train export --config configs/example_export.yaml --weights path/to/weights.ckpt
+```
+
+**Python API:**
+
+```python
+from luxonis_train import LuxonisModel
+
+model = LuxonisModel("configs/example_export.yaml")
+model.export(weights="path/to/weights.ckpt")
+```
+
+Model export can be run automatically at the end of the training by using the `ExportOnTrainEnd` callback.
+
+The exported models are saved in the export directory within your `output` folder.
+
+
+
+## NN Archive
+
+Create an `NN Archive` file for easy deployment with the `DepthAI` API.
+
+The archive contains the exported model together with all the metadata needed for running the model.
+
+**CLI:**
+
+```bash
+luxonis_train archive \
+ --config configs/detection_light_model.yaml \
+ --weights path/to/checkpoint.ckpt
+```
+
+**Python API:**
+
+```python
+from luxonis_train import LuxonisModel
+
+model = LuxonisModel("configs/detection_light_model.yaml")
+model.archive(weights="path/to/checkpoint.ckpt")
+```
+
+The archive can be created automatically at the end of the training by using the `ArchiveOnTrainEnd` callback.
+
+
+
+## Tuning
+
+Optimize your model's performance using hyperparameter tuning powered by [`Optuna`](https://optuna.org/).
+
+**Configuration:**
+
+Include a [`tuner`](https://github.com/luxonis/luxonis-train/blob/main/configs/README.md#tuner) section in your configuration file.
+
+```yaml
+
+tuner:
+ study_name: det_study
+ n_trials: 10
+ storage:
+ storage_type: local
+ params:
+ trainer.optimizer.name_categorical: ["Adam", "SGD"]
+ trainer.optimizer.params.lr_float: [0.0001, 0.001]
+ trainer.batch_size_int: [4, 16, 4]
+```
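+
+In this example, the suffix of each tuned parameter determines how it is sampled: `_categorical` picks one of the listed values, `_float` samples from the given `[min, max]` range, and `_int` samples an integer from `[min, max, step]`.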
+
+**CLI:**
+
+```bash
+luxonis_train tune --config configs/example_tuning.yaml
+```
+
+**Python API:**
+
+```python
+from luxonis_train import LuxonisModel
+
+model = LuxonisModel("configs/example_tuning.yaml")
+model.tune()
+```
+
+
+
+## Customizations
+
+`LuxonisTrain` is highly modular, allowing you to customize various components: loaders, nodes, losses, metrics, visualizers, callbacks, optimizers, and schedulers.
+
+Understanding these components helps in tailoring the framework to your specific needs.
We provide a registry interface through which you can create new
-[nodes](https://github.com/luxonis/luxonis-train/blob/main/luxonis_train/nodes/README.md),
-[losses](https://github.com/luxonis/luxonis-train/blob/main/luxonis_train/attached_modules/losses/README.md),
-[metrics](https://github.com/luxonis/luxonis-train/blob/main/luxonis_train/attached_modules/metrics/README.md),
-[visualizers](https://github.com/luxonis/luxonis-train/blob/main/luxonis_train/attached_modules/visualizers/README.md),
-[callbacks](https://github.com/luxonis/luxonis-train/blob/main/luxonis_train/callbacks/README.md),
-[optimizers](https://github.com/luxonis/luxonis-train/blob/main/configs/README.md#optimizer),
-and [schedulers](https://github.com/luxonis/luxonis-train/blob/main/configs/README.md#scheduler).
-
-Registered components can be then referenced in the config file. Custom components need to inherit from their respective base classes:
-
-- Node - [BaseNode](https://github.com/luxonis/luxonis-train/blob/main/luxonis_train/models/nodes/base_node.py)
-- Loss - [BaseLoss](https://github.com/luxonis/luxonis-train/blob/main/luxonis_train/attached_modules/losses/base_loss.py)
-- Metric - [BaseMetric](https://github.com/luxonis/luxonis-train/blob/main/luxonis_train/attached_modules/metrics/base_metric.py)
-- Visualizer - [BaseVisualizer](https://github.com/luxonis/luxonis-train/blob/main/luxonis_train/attached_modules/visualizers/base_visualizer.py)
-- Callback - [Callback from lightning.pytorch.callbacks](lightning.pytorch.callbacks)
-- Optimizer - [Optimizer from torch.optim](https://pytorch.org/docs/stable/optim.html#torch.optim.Optimizer)
-- Scheduler - [LRScheduler from torch.optim.lr_scheduler](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate)
-
-Here is an example of how to create custom components:
+
+- [**Loaders**](https://github.com/luxonis/luxonis-train/blob/main/luxonis_train/loaders/README.md): Handle data loading and preprocessing.
+- [**Nodes**](https://github.com/luxonis/luxonis-train/blob/main/luxonis_train/nodes/README.md): Represent computational units in the model architecture.
+- [**Losses**](https://github.com/luxonis/luxonis-train/blob/main/luxonis_train/attached_modules/losses/README.md): Define the loss functions used to train the model.
+- [**Metrics**](https://github.com/luxonis/luxonis-train/blob/main/luxonis_train/attached_modules/metrics/README.md): Measure the model's performance during training.
+- [**Visualizers**](https://github.com/luxonis/luxonis-train/blob/main/luxonis_train/attached_modules/visualizers/README.md): Visualize the model's predictions during training.
+- [**Callbacks**](https://github.com/luxonis/luxonis-train/blob/main/luxonis_train/callbacks/README.md): Allow custom code to be executed at different stages of training.
+- [**Optimizers**](https://github.com/luxonis/luxonis-train/blob/main/configs/README.md#optimizer): Control how the model's weights are updated.
+- [**Schedulers**](https://github.com/luxonis/luxonis-train/blob/main/configs/README.md#scheduler): Adjust the learning rate during training.
+
+**Creating Custom Components:**
+
+Implement custom components by subclassing the respective base classes and/or registering them.
+Registered components can be referenced in the config file. Custom components need to inherit from their respective base classes:
+
+- **Loaders** - [`BaseLoader`](https://github.com/luxonis/luxonis-train/blob/main/luxonis_train/loaders/base_loader.py)
+- **Nodes** - [`BaseNode`](https://github.com/luxonis/luxonis-train/blob/main/luxonis_train/models/nodes/base_node.py)
+- **Losses** - [`BaseLoss`](https://github.com/luxonis/luxonis-train/blob/main/luxonis_train/attached_modules/losses/base_loss.py)
+- **Metrics** - [`BaseMetric`](https://github.com/luxonis/luxonis-train/blob/main/luxonis_train/attached_modules/metrics/base_metric.py)
+- **Visualizers** - [`BaseVisualizer`](https://github.com/luxonis/luxonis-train/blob/main/luxonis_train/attached_modules/visualizers/base_visualizer.py)
+- **Callbacks** - [`lightning.pytorch.callbacks.Callback`](https://lightning.ai/docs/pytorch/stable/extensions/callbacks.html), requires manual registration to the `CALLBACKS` registry
+- **Optimizers** - [`torch.optim.Optimizer`](https://pytorch.org/docs/stable/optim.html#torch.optim.Optimizer), requires manual registration to the `OPTIMIZERS` registry
+- **Schedulers** - [`torch.optim.lr_scheduler.LRScheduler`](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate), requires manual registration to the `SCHEDULERS` registry
+
+**Example:**
```python
from torch.optim import Optimizer
@@ -137,82 +604,103 @@ from luxonis_train.attached_modules.losses import BaseLoss
@OPTIMIZERS.register_module()
class CustomOptimizer(Optimizer):
- ...
+ def __init__(self, params, lr=0.001):
+ super().__init__(params, defaults={'lr': lr})
+        # the actual parameter update logic would go in the `step()` method
-# Subclasses of BaseNode, LuxonisLoss, LuxonisMetric
+# Subclasses of BaseNode, BaseLoss, BaseMetric
# and BaseVisualizer are registered automatically.
-
class CustomLoss(BaseLoss):
- # This class is automatically registered under `CustomLoss` name.
+ # This class is automatically registered under the name `CustomLoss`.
def __init__(self, k_steps: int, **kwargs):
super().__init__(**kwargs)
...
```
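+
+A custom callback can be registered manually. A minimal sketch, assuming the `CALLBACKS` registry is importable from `luxonis_train.utils.registry`:
+
+```python
+from lightning.pytorch.callbacks import Callback
+
+from luxonis_train.utils.registry import CALLBACKS  # assumed import path
+
+
+@CALLBACKS.register_module()
+class CustomCallback(Callback):
+    # standard Lightning hook, called once training has finished
+    def on_train_end(self, trainer, pl_module):
+        print("Training finished!")
+```
+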
-And then in the config you reference this `CustomOptimizer` and `CustomLoss` by their names:
+**Using custom components in config:**
```yaml
-losses:
- - name: CustomLoss
- params: # additional parameters
- k_steps: 12
-
+model:
+ nodes:
+ - name: SegmentationHead
+ losses:
+ - name: CustomLoss
+ params:
+ k_steps: 12
+
+optimizer:
+ name: CustomOptimizer
+ params:
+ lr: 0.01
```
-For more information on how to define custom components, consult the respective in-source documentation.
+> \[!NOTE\]
+> Files containing the custom components must be sourced before the training script is run.
+> To do that in CLI, you can use the `--source` argument:
+>
+> ```bash
+> luxonis_train --source custom_components.py train --config config.yaml
+> ```
-## Credentials
+**Python API:**
-Local use is supported by default. In addition, we also integrate some cloud services which can be primarily used for logging and storing. When these are used, you need to load environment variables to set up the correct credentials.
+You have to import the custom components before creating the `LuxonisModel` instance.
-You have these options how to set up the environment variables:
+```python
+from custom_components import *
+from luxonis_train import LuxonisModel
-- Using standard environment variables
-- Specifying the variables in a `.env` file. If a variable is both in the environment and present in `.env` file, the exported variable takes precedence.
-- Specifying the variables in the [ENVIRON](https://github.com/luxonis/luxonis-train/blob/main/configs/README.md#environ) section of the config file. Note that this is not a recommended way. Variables defined in config take precedence over environment and `.env` variables.
+model = LuxonisModel("config.yaml")
+model.train()
+```
-### S3
+For more information on how to define custom components, consult the respective in-source documentation.
-If you are working with LuxonisDataset that is hosted on S3, you need to specify these env variables:
+
-```bash
-AWS_ACCESS_KEY_ID=**********
-AWS_SECRET_ACCESS_KEY=**********
-AWS_S3_ENDPOINT_URL=**********
-```
+## Tutorials and Examples
-### MLFlow
+We are actively working on examples and tutorials for different parts of the library to help you get started more easily. They can be found [here](https://github.com/luxonis/depthai-ml-training/tree/master) and will be updated regularly.
-If you want to use MLFlow for logging and storing artifacts you also need to specify MLFlow-related env variables like this:
+
-```bash
-MLFLOW_S3_BUCKET=**********
-MLFLOW_S3_ENDPOINT_URL=**********
-MLFLOW_TRACKING_URI=**********
-```
+## Credentials
-### WandB
+When using cloud services, avoid hard-coding credentials or placing them directly in your configuration files.
+Instead:
-If you are using WandB for logging, you have to sign in first in your environment.
+- Use environment variables to store sensitive information.
+- Use a `.env` file and load it securely, ensuring it's excluded from version control.
-### POSTGRESS
+**Supported Cloud Services:**
-There is an option for remote storage for [Tuning](#tuning). We use POSTGRES and to connect to the database you need to specify the following env variables:
+- **AWS S3**, requires:
+ - `AWS_ACCESS_KEY_ID`
+ - `AWS_SECRET_ACCESS_KEY`
+ - `AWS_S3_ENDPOINT_URL`
+- **Google Cloud Storage**, requires:
+ - `GOOGLE_APPLICATION_CREDENTIALS`
+- **RoboFlow**, requires:
+ - `ROBOFLOW_API_KEY`
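+
+For example, the **AWS S3** credentials can be supplied through a `.env` file (the values below are placeholders):
+
+```bash
+AWS_ACCESS_KEY_ID=**********
+AWS_SECRET_ACCESS_KEY=**********
+AWS_S3_ENDPOINT_URL=**********
+```
+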
-```bash
-POSTGRES_USER=**********
-POSTGRES_PASSWORD=**********
-POSTGRES_HOST=**********
-POSTGRES_PORT=**********
-POSTGRES_DB=**********
-```
+**For logging and tracking, we support:**
-## Contributing
+- **MLFlow**, requires:
+ - `MLFLOW_S3_BUCKET`
+ - `MLFLOW_S3_ENDPOINT_URL`
+ - `MLFLOW_TRACKING_URI`
+- **WandB**, requires:
+ - `WANDB_API_KEY`
-If you want to contribute to the development, install the dev version of the package:
+**For remote database storage (used for [tuning](#tuning)), we support `PostgreSQL`, which requires:**
-```bash
-pip install luxonis-train[dev]
-```
+- `POSTGRES_USER`
+- `POSTGRES_PASSWORD`
+- `POSTGRES_HOST`
+- `POSTGRES_PORT`
+- `POSTGRES_DB`
+
+
+
+## Contributing
-Consult the [Contribution guide](https://github.com/luxonis/luxonis-train/blob/main/CONTRIBUTING.md) for further instructions.
+We welcome contributions! Please read our [Contribution Guide](https://github.com/luxonis/luxonis-train/blob/main/CONTRIBUTING.md) to get started. Whether it's reporting bugs, improving documentation, or adding new features, your help is appreciated.
diff --git a/configs/README.md b/configs/README.md
index 384f6220..b06c9495 100644
--- a/configs/README.md
+++ b/configs/README.md
@@ -1,6 +1,6 @@
# Configuration
-The configuration is defined in a yaml file, which you must provide.
+The configuration is defined in a `YAML` file, which you must provide.
The configuration file consists of a few major blocks that are described below.
You can create your own config or use/edit one of the examples.
@@ -9,19 +9,20 @@ You can create your own config or use/edit one of the examples.
- [Top-level Options](#top-level-options)
- [Model](#model)
- [Nodes](#nodes)
- - [Attached Modules](#attached-modules)
- [Losses](#losses)
- [Metrics](#metrics)
- [Visualizers](#visualizers)
- [Tracker](#tracker)
- [Loader](#loader)
-- [Trainer](#train)
+ - [`LuxonisLoaderTorch`](#luxonisloadertorch)
+- [Trainer](#trainer)
- [Preprocessing](#preprocessing)
+ - [Augmentations](#augmentations)
+ - [Callbacks](#callbacks)
- [Optimizer](#optimizer)
- [Scheduler](#scheduler)
- - [Callbacks](#callbacks)
- [Exporter](#exporter)
- - [ONNX](#onnx)
+ - [`ONNX`](#onnx)
- [Blob](#blob)
- [Tuner](#tuner)
- [Storage](#storage)
@@ -29,288 +30,462 @@ You can create your own config or use/edit one of the examples.
## Top-level Options
-| Key | Type | Default value | Description |
-| -------- | --------------------- | ------------- | ---------------- |
-| model | [Model](#model) | | Model section |
-| loader | [loader](#loader) | | Loader section |
-| train | [train](#train) | | Train section |
-| tracker | [tracker](#tracker) | | Tracker section |
-| trainer | [trainer](#trainer) | | Trainer section |
-| exporter | [exporter](#exporter) | | Exporter section |
-| tuner | [tuner](#tuner) | | Tuner section |
+| Key | Type | Description |
+| ---------- | ----------------------- | ---------------- |
+| `model` | [`model`](#model) | Model section |
+| `loader` | [`loader`](#loader) | Loader section |
+| `tracker` | [`tracker`](#tracker) | Tracker section |
+| `trainer` | [`trainer`](#trainer) | Trainer section |
+| `exporter` | [`exporter`](#exporter) | Exporter section |
+| `tuner` | [`tuner`](#tuner) | Tuner section |
## Model
This is the most important block, that **must be always defined by the user**. There are two different ways you can create the model.
-| Key | Type | Default value | Description |
-| ---------------- | ---- | ------------- | ---------------------------------------------------------- |
-| name | str | "model" | Name of the model |
-| weights | path | None | Path to weights to load |
-| predefined_model | str | None | Name of a predefined model to use |
-| params | dict | {} | Parameters for the predefined model |
-| nodes | list | \[\] | List of nodes (see [nodes](#nodes) |
-| losses | list | \[\] | lList of losses (see [losses](#losses) |
-| metrics | list | \[\] | List of metrics (see [metrics](#metrics) |
-| visualziers | list | \[\] | List of visualizers (see [visualizers](#visualizers) |
-| outputs | list | \[\] | List of outputs nodes, inferred from nodes if not provided |
+| Key | Type | Default value | Description |
+| ------------------ | ------ | ------------- | ---------------------------------------------------------- |
+| `name` | `str` | `"model"` | Name of the model |
+| `weights` | `path` | `None` | Path to weights to load |
+| `predefined_model` | `str` | `None` | Name of a predefined model to use |
+| `params` | `dict` | `{}` | Parameters for the predefined model |
+| `nodes` | `list` | `[]` | List of nodes (see [nodes](#nodes)) |
+| `losses` | `list` | `[]` | List of losses (see [losses](#losses)) |
+| `metrics` | `list` | `[]` | List of metrics (see [metrics](#metrics)) |
+| `visualizers`      | `list` | `[]`          | List of visualizers (see [visualizers](#visualizers))      |
+| `outputs`          | `list` | `[]`          | List of output nodes, inferred from nodes if not provided  |
### Nodes
For list of all nodes, see [nodes](../luxonis_train/nodes/README.md).
-| Key | Type | Default value | Description |
-| ----------------------- | -------------------- | ------------- | ------------------------------------------------------------------------------------------------------------------------------------------- |
-| name | str | | Name of the node |
-| alias | str | None | Custom name for the node |
-| params | dict | {} | Parameters for the node |
-| inputs | list | \[\] | List of input nodes for this node, if empty, the node is understood to be an input node of the model |
-| freezing.active | bool | False | whether to freeze the modules so the weights are not updated |
-| freezing.unfreeze_after | int \| float \| None | None | After how many epochs should the modules be unfrozen, can be `int` for a specific number of epochs or `float` for a portion of the training |
-| remove_on_export | bool | False | Whether the node should be removed when exporting |
-
-### Attached Modules
-
-Modules that are attached to a node. This include losses, metrics and visualziers.
-
-| Key | Type | Default value | Description |
-| ----------- | ---- | ------------- | ------------------------------------------- |
-| name | str | | Name of the module |
-| attached_to | str | | Name of the node the module is attached to. |
-| alias | str | None | Custom name for the module |
-| params | dict | {} | Parameters of the module |
+| Key | Type | Default value | Description |
+| ------------------------- | ---------------------- | ------------- | ------------------------------------------------------------------------------------------------------------------------------------------- |
+| `name` | `str` | - | Name of the node |
+| `alias` | `str` | `None` | Custom name for the node |
+| `params` | `dict` | `{}` | Parameters for the node |
+| `inputs` | `list` | `[]` | List of input nodes for this node, if empty, the node is understood to be an input node of the model |
+| `freezing.active`         | `bool`                 | `False`       | Whether to freeze the modules so the weights are not updated                                                                                 |
+| `freezing.unfreeze_after` | `int \| float \| None` | `None` | After how many epochs should the modules be unfrozen, can be `int` for a specific number of epochs or `float` for a portion of the training |
+| `remove_on_export` | `bool` | `False` | Whether the node should be removed when exporting |
+| `losses` | `list` | `[]` | List of losses attached to this node |
+| `metrics` | `list` | `[]` | List of metrics attached to this node |
+| `visualizers` | `list` | `[]` | List of visualizers attached to this node |
#### Losses
At least one node must have a loss attached to it.
You can see the list of all currently supported loss functions and their parameters [here](../luxonis_train/attached_modules/losses/README.md).
-| Key | Type | Default value | Description |
-| ------ | ----- | ------------- | ---------------------------------------- |
-| weight | float | 1.0 | Weight of the loss used in the final sum |
+| Key | Type | Default value | Description |
+| -------- | ------- | ------------- | ---------------------------------------- |
+| `weight` | `float` | `1.0` | Weight of the loss used in the final sum |
+| `alias` | `str` | `None` | Custom name for the loss |
+| `params` | `dict` | `{}` | Additional parameters for the loss |
#### Metrics
In this section, you configure which metrics should be used for which node.
You can see the list of all currently supported metrics and their parameters [here](../luxonis_train/attached_modules/metrics/README.md).
-| Key | Type | Default value | Description |
-| -------------- | ---- | ------------- | --------------------------------------------------------------------------------------- |
-| is_main_metric | bool | False | Marks this specific metric as the main one. Main metric is used for saving checkpoints. |
+| Key | Type | Default value | Description |
+| ---------------- | ------ | ------------- | -------------------------------------------------------------------------------------- |
+| `is_main_metric` | `bool` | `False` | Marks this specific metric as the main one. Main metric is used for saving checkpoints |
+| `alias` | `str` | `None` | Custom name for the metric |
+| `params` | `dict` | `{}` | Additional parameters for the metric |
#### Visualizers
In this section, you configure which visualizers should be used for which node. Visualizers are responsible for creating images during training.
You can see the list of all currently supported visualizers and their parameters [here](../luxonis_train/attached_modules/visualizers/README.md).
-Visualizers have no specific configuration.
+| Key | Type | Default value | Description |
+| -------- | ------ | ------------- | ---------------------------------------- |
+| `alias` | `str` | `None` | Custom name for the visualizer |
+| `params` | `dict` | `{}` | Additional parameters for the visualizer |
+
+**Example:**
+
+```yaml
+name: "SegmentationHead"
+inputs:
+ - "RepPANNeck"
+losses:
+ - name: "BCEWithLogitsLoss"
+metrics:
+ - name: "F1Score"
+ params:
+ task: "binary"
+ - name: "JaccardIndex"
+ params:
+ task: "binary"
+visualizers:
+ - name: "SegmentationVisualizer"
+ params:
+ colors: "#FF5055"
+```
## Tracker
-This library uses [LuxonisTrackerPL](https://github.com/luxonis/luxonis-ml/blob/b2399335efa914ef142b1b1a5db52ad90985c539/src/luxonis_ml/ops/tracker.py#L152).
+This library uses [`LuxonisTrackerPL`](https://github.com/luxonis/luxonis-ml/blob/b2399335efa914ef142b1b1a5db52ad90985c539/src/luxonis_ml/ops/tracker.py#L152).
You can configure it like this:
-| Key | Type | Default value | Description |
-| -------------- | ----------- | ------------- | ---------------------------------------------------------- |
-| project_name | str \| None | None | Name of the project used for logging. |
-| project_id | str \| None | None | Id of the project used for logging (relevant for MLFlow). |
-| run_name | str \| None | None | Name of the run. If empty, then it will be auto-generated. |
-| run_id | str \| None | None | Id of an already created run (relevant for MLFLow.) |
-| save_directory | str | "output" | Path to the save directory. |
-| is_tensorboard | bool | True | Whether to use tensorboard. |
-| is_wandb | bool | False | Whether to use WandB. |
-| wandb_entity | str \| None | None | Name of WandB entity. |
-| is_mlflow | bool | False | Whether to use MLFlow. |
+| Key | Type | Default value | Description |
+| ---------------- | ------------- | ------------- | ---------------------------------------------------------- |
+| `project_name` | `str \| None` | `None` | Name of the project used for logging |
+| `project_id` | `str \| None` | `None` | ID of the project used for logging (relevant for `MLFlow`) |
+| `run_name` | `str \| None` | `None` | Name of the run. If empty, then it will be auto-generated |
+| `run_id` | `str \| None` | `None` | ID of an already created run (relevant for `MLFLow`) |
+| `save_directory` | `str` | `"output"` | Path to the save directory |
+| `is_tensorboard` | `bool` | `True` | Whether to use `Tensorboard` |
+| `is_wandb` | `bool` | `False` | Whether to use `WandB` |
+| `wandb_entity` | `str \| None` | `None` | Name of `WandB` entity |
+| `is_mlflow` | `bool` | `False` | Whether to use `MLFlow` |
+
+**Example:**
+
+```yaml
+tracker:
+ project_name: "project_name"
+ save_directory: "output"
+ is_tensorboard: true
+ is_wandb: false
+ is_mlflow: false
+```
## Loader
This section controls the data loading process and parameters regarding the dataset.
-To store and load the data we use LuxonisDataset and LuxonisLoader. For specific config parameters refer to [LuxonisML](https://github.com/luxonis/luxonis-ml).
+To store and load the data we use `LuxonisDataset` and `LuxonisLoader`. For specific config parameters refer to [`LuxonisML`](https://github.com/luxonis/luxonis-ml).
-| Key | Type | Default value | Description |
-| ------------ | ------------------ | ------------------ | -------------------------------- |
-| name | str | LuxonisLoaderTorch | Name of the Loader |
-| image_source | str | image | Name of the input image group |
-| train_view | str \| list\[str\] | train | splits to use for training |
-| val_view | str \| list\[str\] | val | splits to use for validation |
-| test_view | str \| list\[str\] | test | splits to use for testing |
-| params | Dict\[str, Any\] | {} | Additional parameters for loader |
+| Key | Type | Default value | Description |
+| -------------- | ------------------ | ---------------------- | ------------------------------------ |
+| `name` | `str` | `"LuxonisLoaderTorch"` | Name of the Loader |
+| `image_source` | `str` | `"image"` | Name of the input image group |
+| `train_view`   | `str \| list[str]` | `"train"`              | Splits to use for training           |
+| `val_view`     | `str \| list[str]` | `"val"`                | Splits to use for validation         |
+| `test_view`    | `str \| list[str]` | `"test"`               | Splits to use for testing            |
+| `params` | `dict[str, Any]` | `{}` | Additional parameters for the loader |
-### LuxonisLoaderTorch
+### `LuxonisLoaderTorch`
-By default LuxonisLoaderTorch which can either use an existing LuxonisDataset or create a new one if it can be parsed automatically by LuxonisParser (check [LuxonisML](https://github.com/luxonis/luxonis-ml) `data` subpackage for more info).
+The default loader is `LuxonisLoaderTorch`. It can either use an existing `LuxonisDataset` or create a new one if the data can be parsed automatically by `LuxonisParser` (check the [`LuxonisML`](https://github.com/luxonis/luxonis-ml) `data` sub-package for more info).
-In most cases you want to change one of the parameters below. You can check all the parameters in `LuxonisLoaderTorch` class itself.
+In most cases you want to set one of the parameters below. You can check all the parameters in the `LuxonisLoaderTorch` class itself.
-| dataset_name | str | None | None | Name of an existing LuxonisDataset. |
-| dataset_dir | str | None | None | Location of the data from which new LuxonisDataset will be created |
-| dataset_type | DatasetType | None | None | Can specify exact format of the data. If None and new dataset needs to be created then it will be infered automatically. |
+| Key | Type | Default value | Description |
+| -------------- | ----- | ------------- | -------------------------------------------------------------------- |
+| `dataset_name` | `str` | `None` | Name of an existing `LuxonisDataset` |
+| `dataset_dir` | `str` | `None` | Location of the data from which new `LuxonisDataset` will be created |
+
+**Example:**
+
+```yaml
+loader:
+ # using default loader with an existing dataset
+ params:
+ dataset_name: "dataset_name"
+```
+
+```yaml
+loader:
+ # using default loader with a directory
+ params:
+ dataset_name: "dataset_name"
+ dataset_dir: "path/to/dataset"
+```
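+
+The `dataset_dir` can also point to a remote location, as shown in the main `README`:
+
+```yaml
+loader:
+  # parse a remote dataset from RoboFlow and store it as `coco_test`
+  params:
+    dataset_name: "coco_test"
+    dataset_dir: "roboflow://team-roboflow/coco-128/2/coco"
+```
+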
## Trainer
Here you can change everything related to actual training of the model.
-| Key | Type | Default value | Description |
-| ----------------------- | ---------------------------------------------- | ------------- | ------------------------------------------------------------------------------------------------------------------------------------------------ |
-| seed | int | None | Seed for reproducibility |
-| deterministic | bool \| "warn" \| None | None | Whether pytorch should use deterministic backend |
-| batch_size | int | 32 | Batch size used for training |
-| accumulate_grad_batches | int | 1 | Number of batches for gradient accumulation |
-| use_weighted_sampler | bool | False | Bool if use WeightedRandomSampler for training, only works with classification tasks |
-| epochs | int | 100 | Number of training epochs |
-| n_workers | int | 4 | Number of workers for data loading |
-| validation_interval | int | 5 | Frequency of computing metrics on validation data |
-| n_log_images | int | 4 | Maximum number of images to visualize and log |
-| skip_last_batch | bool | True | Whether to skip last batch while training |
-| accelerator | Literal\["auto", "cpu", "gpu"\] | "auto" | What accelerator to use for training. |
-| devices | int \| list\[int\] \| str | "auto" | Either specify how many devices to use (int), list specific devices, or use "auto" for automatic configuration based on the selected accelerator |
-| matmul_precision | Literal\["medium", "high", "highest"\] \| None | None | Sets the internal precision of float32 matrix multiplications. |
-| strategy | Literal\["auto", "ddp"\] | "auto" | What strategy to use for training. |
-| n_sanity_val_steps | int | 2 | Number of sanity validation steps performed before training. |
-| profiler | Literal\["simple", "advanced"\] \| None | None | PL profiler for GPU/CPU/RAM utilization analysis |
-| verbose | bool | True | Print all intermediate results to console. |
-| pin_memory | bool | True | Whether to pin memory in the DataLoader |
-| save_top_k | -1 \| NonNegativeInt | 3 | Save top K checkpoints based on validation loss when training. |
+| Key | Type | Default value | Description |
+| ------------------------- | ---------------------------------------------- | ------------- | ------------------------------------------------------------------------------------------------------------------------------------------------ |
+| `seed` | `int` | `None` | Seed for reproducibility |
+| `deterministic` | `bool \| "warn" \| None` | `None` | Whether PyTorch should use deterministic backend |
+| `batch_size` | `int` | `32` | Batch size used for training |
+| `accumulate_grad_batches` | `int` | `1` | Number of batches for gradient accumulation |
+| `use_weighted_sampler` | `bool` | `False` | Whether to use `WeightedRandomSampler` for training, only works with classification tasks |
+| `epochs` | `int` | `100` | Number of training epochs |
+| `n_workers` | `int` | `4` | Number of workers for data loading |
+| `validation_interval` | `int` | `5` | Frequency of computing metrics on validation data |
+| `n_log_images` | `int` | `4` | Maximum number of images to visualize and log |
+| `skip_last_batch` | `bool` | `True` | Whether to skip last batch while training |
+| `accelerator` | `Literal["auto", "cpu", "gpu"]` | `"auto"` | What accelerator to use for training |
+| `devices` | `int \| list[int] \| str` | `"auto"` | Either specify how many devices to use (int), list specific devices, or use "auto" for automatic configuration based on the selected accelerator |
+| `matmul_precision` | `Literal["medium", "high", "highest"] \| None` | `None` | Sets the internal precision of float32 matrix multiplications |
+| `strategy` | `Literal["auto", "ddp"]` | `"auto"` | What strategy to use for training |
+| `n_sanity_val_steps` | `int` | `2` | Number of sanity validation steps performed before training |
+| `profiler` | `Literal["simple", "advanced"] \| None` | `None` | PL profiler for GPU/CPU/RAM utilization analysis |
+| `verbose` | `bool` | `True` | Print all intermediate results to console |
+| `pin_memory` | `bool` | `True` | Whether to pin memory in the `DataLoader` |
+| `save_top_k` | `-1 \| NonNegativeInt` | `3` | Save top K checkpoints based on validation loss when training |
+
+**Example:**
+
+```yaml
+
+trainer:
+ accelerator: "auto"
+ devices: "auto"
+ strategy: "auto"
+
+ n_sanity_val_steps: 1
+ profiler: null
+ verbose: true
+ batch_size: 8
+ accumulate_grad_batches: 1
+ epochs: 200
+ n_workers: 8
+ validation_interval: 10
+ n_log_images: 8
+ skip_last_batch: true
+ log_sub_losses: true
+ save_top_k: 3
+```
### Preprocessing
-We use [Albumentations](https://albumentations.ai/docs/) library for `augmentations`. [Here](https://albumentations.ai/docs/api_reference/full_reference/#pixel-level-transforms) you can see a list of all pixel level augmentations supported, and [here](https://albumentations.ai/docs/api_reference/full_reference/#spatial-level-transforms) you see all spatial level transformations. In config you can specify any augmentation from this lists and their params.
+We use the [`Albumentations`](https://albumentations.ai/docs/) library for augmentations. [Here](https://albumentations.ai/docs/api_reference/full_reference/#pixel-level-transforms) you can see a list of all supported pixel-level augmentations, and [here](https://albumentations.ai/docs/api_reference/full_reference/#spatial-level-transforms) all supported spatial-level transformations. In the configuration you can specify any augmentation from these lists together with its parameters.
+
+Additionally, we support `Mosaic4` and `MixUp` batch augmentations and letterbox resizing if `keep_aspect_ratio: true`.
+
+| Key | Type | Default value | Description |
+| ------------------- | ------------ | ------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `train_image_size` | `list[int]` | `[256, 256]` | Image size used for training as `[height, width]` |
+| `keep_aspect_ratio` | `bool` | `True` | Whether to keep the aspect ratio while resizing |
+| `train_rgb` | `bool` | `True` | Whether to train on RGB or BGR images |
+| `normalize.active` | `bool` | `True` | Whether to use normalization |
+| `normalize.params` | `dict` | `{}` | Parameters for normalization, see [Normalize](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.Normalize) |
+| `augmentations` | `list[dict]` | `[]` | List of `Albumentations` augmentations |
+
+#### Augmentations
+
+| Key | Type | Default value | Description |
+| -------- | ------ | ------------- | ---------------------------------- |
+| `name` | `str` | - | Name of the augmentation |
+| `active` | `bool` | `True` | Whether the augmentation is active |
+| `params` | `dict` | `{}` | Parameters of the augmentation |
+
+**Example:**
+
+```yaml
+
+trainer:
+ preprocessing:
+    # using a YAML anchor to reuse the image size
+ train_image_size: [&height 384, &width 384]
+ keep_aspect_ratio: true
+ train_rgb: true
+ normalize:
+ active: true
+ augmentations:
+ - name: "Defocus"
+ params:
+ p: 0.1
+ - name: "Sharpen"
+ params:
+ p: 0.1
+ - name: "Flip"
+ - name: "RandomRotate90"
+ - name: "Mosaic4"
+ params:
+ out_width: *width
+ out_height: *height
-Additionaly we support `Mosaic4` and `MixUp` batch augmentations and letterbox resizing if `keep_aspect_ratio: True`.
+```
-| Key | Type | Default value | Description |
-| ----------------- | --------------------------------------------------------------------------------------------------------------------------------------------- | ------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| train_image_size | list\[int\] | \[256, 256\] | Image size used for training \[height, width\] |
-| keep_aspect_ratio | bool | True | Bool if keep aspect ration while resizing |
-| train_rgb | bool | True | Bool if train on rgb or bgr |
-| normalize.active | bool | True | Bool if use normalization |
-| normalize.params | dict | {} | Params for normalization, see [documentation](https://albumentations.ai/docs/api_reference/augmentations/transforms/#albumentations.augmentations.transforms.Normalize) |
-| augmentations | list\[{"name": Name of the augmentation, "active": Bool if aug is active, by default set to True, "params": Parameters of the augmentation}\] | \[\] | List of Albumentations augmentations |
+### Callbacks
+
+The callbacks section contains a list of callbacks.
+More information on callbacks and a list of the available ones can be found [here](../luxonis_train/callbacks/README.md).
+Each callback is a dictionary with the following fields:
+
+| Key | Type | Default value | Description |
+| -------- | ------ | ------------- | -------------------------- |
+| `name` | `str` | - | Name of the callback |
+| `active` | `bool` | `True` | Whether callback is active |
+| `params` | `dict` | `{}` | Parameters of the callback |
+
+**Example:**
+
+```yaml
+
+trainer:
+ callbacks:
+ - name: "LearningRateMonitor"
+ params:
+ logging_interval: "step"
+    - name: "MetadataLogger"
+ params:
+ hyperparams: ["trainer.epochs", "trainer.batch_size"]
+ - name: "EarlyStopping"
+ params:
+ patience: 3
+ monitor: "val/loss"
+ mode: "min"
+ verbose: true
+ - name: "ExportOnTrainEnd"
+ - name: "TestOnTrainEnd"
+```
### Optimizer
What optimizer to use for training.
List of all optimizers can be found [here](https://pytorch.org/docs/stable/optim.html).
-| Key | Type | Default value | Description |
-| ------ | ---- | ------------- | ---------------------------- |
-| name | str | "Adam" | Name of the optimizer. |
-| params | dict | {} | Parameters of the optimizer. |
+| Key | Type | Default value | Description |
+| -------- | ------ | ------------- | --------------------------- |
+| `name` | `str` | `"Adam"` | Name of the optimizer |
+| `params` | `dict` | `{}` | Parameters of the optimizer |
+
+**Example:**
+
+```yaml
+optimizer:
+ name: "SGD"
+ params:
+ lr: 0.02
+ momentum: 0.937
+ nesterov: true
+ weight_decay: 0.0005
+```
### Scheduler
What scheduler to use for training.
List of all optimizers can be found [here](https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate).
-| Key | Type | Default value | Description |
-| ------ | ---- | ------------- | ---------------------------- |
-| name | str | "ConstantLR" | Name of the scheduler. |
-| params | dict | {} | Parameters of the scheduler. |
+| Key | Type | Default value | Description |
+| -------- | ------ | -------------- | --------------------------- |
+| `name` | `str` | `"ConstantLR"` | Name of the scheduler |
+| `params` | `dict` | `{}` | Parameters of the scheduler |
-### Callbacks
-
-Callbacks sections contains a list of callbacks.
-More information on callbacks and a list of available ones can be found [here](../luxonis_train/callbacks/README.md)
-Each callback is a dictionary with the following fields:
+**Example:**
-| Key | Type | Default value | Description |
-| ------ | ---- | ------------- | --------------------------- |
-| name | str | | Name of the callback. |
-| active | bool | True | Whether calback is active. |
-| params | dict | {} | Parameters of the callback. |
+```yaml
+trainer:
+ scheduler:
+ name: "CosineAnnealingLR"
+ params:
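+      # `*epochs` refers to a YAML anchor assumed to be defined elsewhere
+      # in the config (e.g. `epochs: &epochs 300` under `trainer`)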
+ T_max: *epochs
+ eta_min: 0
+```
## Exporter
Here you can define configuration for exporting.
-| Key | Type | Default value | Description |
-| ---------------------- | --------------------------------- | ------------- | ----------------------------------------------------------------------------------------------- |
-| name | str \| None | None | Name of the exported model. |
-| input_shape | list\[int\] \| None | None | Input shape of the model. If not provided, inferred from the dataset. |
-| data_type | Literal\["INT8", "FP16", "FP32"\] | "FP16" | Data type of the exported model. Only used for conversion to BLOB. |
-| reverse_input_channels | bool | True | Whether to reverse the image channels in the exported model. Relevant for `.blob` export |
-| scale_values | list\[float\] \| None | None | What scale values to use for input normalization. If not provided, inferred from augmentations. |
-| mean_values | list\[float\] \| None | None | What mean values to use for input normalizations. If not provided, inferred from augmentations. |
-| upload_to_run | bool | True | Whether to upload the exported files to tracked run as artifact. |
-| upload_url | str \| None | None | Exported model will be uploaded to this url if specified. |
+| Key | Type | Default value | Description |
+| ------------------------ | --------------------------------- | ------------- | ---------------------------------------------------------------------------------------------- |
+| `name` | `str \| None` | `None` | Name of the exported model |
+| `input_shape`            | `list[int] \| None`               | `None`        | Input shape of the model. If not provided, inferred from the dataset                            |
+| `data_type` | `Literal["INT8", "FP16", "FP32"]` | `"FP16"` | Data type of the exported model. Only used for conversion to BLOB |
+| `reverse_input_channels` | `bool` | `True` | Whether to reverse the image channels in the exported model. Relevant for `BLOB` export |
+| `scale_values` | `list[float] \| None` | `None` | What scale values to use for input normalization. If not provided, inferred from augmentations |
+| `mean_values` | `list[float] \| None` | `None` | What mean values to use for input normalization. If not provided, inferred from augmentations |
+| `upload_to_run` | `bool` | `True` | Whether to upload the exported files to tracked run as artifact |
+| `upload_url` | `str \| None` | `None` | Exported model will be uploaded to this URL if specified |
+| `output_names` | `list[str] \| None` | `None` | Optional list of output names to override the default ones |
-### ONNX
+### `ONNX`
-Option specific for ONNX export.
+Options specific to `ONNX` export.
-| Key | Type | Default value | Description |
-| ------------- | ------------------------ | ------------- | -------------------------------- |
-| opset_version | int | 12 | Which opset version to use. |
-| dynamic_axes | dict\[str, Any\] \| None | None | Whether to specify dinamic axes. |
+| Key | Type | Default value | Description |
+| --------------- | ------------------------ | ------------- | --------------------------------- |
+| `opset_version` | `int` | `12` | Which `ONNX` opset version to use |
+| `dynamic_axes` | `dict[str, Any] \| None` | `None` | Whether to specify dynamic axes |
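+
+**Example:**
+
+A sketch of enabling a dynamic batch dimension. The `dynamic_axes` dictionary follows the `torch.onnx.export` convention; the `input` and `output` tensor names are illustrative and depend on the actual model:
+
+```yaml
+exporter:
+  onnx:
+    opset_version: 12
+    dynamic_axes:
+      input: { 0: "batch" }
+      output: { 0: "batch" }
+```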
### Blob
-| Key | Type | Default value | Description |
-| ------- | ---------------------------------------------------------------- | ------------- | --------------------------------------- |
-| active | bool | False | Whether to export to `.blob` format. |
-| shaves | int | 6 | How many shaves. |
-| version | Literal\["2021.2", "2021.3", "2021.4", "2022.1", "2022.3_RVC3"\] | "2022.1" | OpenVINO version to use for conversion. |
+| Key | Type | Default value | Description |
+| --------- | ---------------------------------------------------------------- | ------------- | ---------------------------------------- |
+| `active` | `bool` | `False` | Whether to export to `BLOB` format |
+| `shaves`  | `int`                                                            | `6`           | How many `SHAVE` cores to compile the model for  |
+| `version` | `Literal["2021.2", "2021.3", "2021.4", "2022.1", "2022.3_RVC3"]` | `"2022.1"` | `OpenVINO` version to use for conversion |
+
+**Example:**
+
+```yaml
+exporter:
+ output_names: ["output1", "output2"]
+ onnx:
+ opset_version: 11
+ blobconverter:
+ active: true
+ shaves: 8
+```
## Tuner
Here you can specify options for tuning.
-| Key | Type | Default value | Description |
-| ---------------------- | ----------------- | ------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| study_name | str | "test-study" | Name of the study. |
-| continue_exising_study | bool | True | Whether to continue an existing study with this name. |
-| use_pruner | bool | True | Whether to use the MedianPruner. |
-| n_trials | int \| None | 15 | Number of trials for each process. `None` represents no limit in terms of numbner of trials. |
-| timeout | int \| None | None | Stop study after the given number of seconds. |
-| params | dict\[str, list\] | {} | Which parameters to tune. The keys should be in the format `key1.key2.key3_`. Type can be one of `[categorical, float, int, longuniform, uniform, subset]`. For more information about the types, visit [Optuna documentation](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html). |
+| Key | Type | Default value | Description |
+| ------------------------ | ----------------- | -------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `study_name` | `str` | `"test-study"` | Name of the study |
+| `continue_exising_study` | `bool` | `True` | Whether to continue an existing study with this name |
+| `use_pruner` | `bool` | `True` | Whether to use the `MedianPruner` |
+| `n_trials` | `int \| None` | `15` | Number of trials for each process. `None` represents no limit in terms of number of trials |
+| `timeout` | `int \| None` | `None` | Stop study after the given number of seconds |
+| `params` | `dict[str, list]` | `{}` | Which parameters to tune. The keys should be in the format `key1.key2.key3_`. Type can be one of `[categorical, float, int, longuniform, uniform, subset]`. For more information about the types, visit [`Optuna` documentation](https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html) |
+
+> \[!NOTE\]
+> `"subset"` sampling is currently only supported for augmentations.
+> You can specify a set of augmentations defined in `trainer` to choose from.
+> In each trial, a random subset of $N$ augmentations will be active (the `is_active` parameter will be `True` for the chosen ones and `False` for the rest of the set).
+
+### Storage
-**Note**: "subset" sampling is currently only supported for augmentations. You can specify a set of augmentations defined in `trainer` to choose from and every run subset of random N augmentations will be active (`is_active` parameter will be True for chosen ones and False for the rest in the set).
+| Key | Type | Default value | Description |
+| -------------- | ---------------------------- | ------------- | --------------------------------------------------- |
+| `active` | `bool` | `True` | Whether to use storage to make the study persistent |
+| `storage_type` | `Literal["local", "remote"]` | `"local"` | Type of the storage |
-Example of params for tuner block:
+**Example:**
```yaml
-tuner:
+tuner:
+ study_name: "seg_study"
+ n_trials: 10
+ storage:
+ storage_type: "local"
params:
trainer.optimizer.name_categorical: ["Adam", "SGD"]
trainer.optimizer.params.lr_float: [0.0001, 0.001]
trainer.batch_size_int: [4, 16, 4]
+ # each run will have 2 of the following augmentations active
trainer.preprocessing.augmentations_subset: [["Defocus", "Sharpen", "Flip"], 2]
```
-### Storage
-
-| Key | Type | Default value | Description |
-| ------------ | ---------------------------- | ------------- | ---------------------------------------------------- |
-| active | bool | True | Whether to use storage to make the study persistent. |
-| storage_type | Literal\["local", "remote"\] | "local" | Type of the storage. |
-
## ENVIRON
A special section of the config file where you can specify environment variables.
For more info on the variables, see [Credentials](../README.md#credentials).
-**NOTE**
-
-This is not a recommended way due to possible leakage of secrets. This section is intended for testing purposes only.
-
-| Key | Type | Default value | Description |
-| ------------------------ | ---------------------------------------------------------- | -------------- | ----------- |
-| AWS_ACCESS_KEY_ID | str \| None | None | |
-| AWS_SECRET_ACCESS_KEY | str \| None | None | |
-| AWS_S3_ENDPOINT_URL | str \| None | None | |
-| MLFLOW_CLOUDFLARE_ID | str \| None | None | |
-| MLFLOW_CLOUDFLARE_SECRET | str \| None | None | |
-| MLFLOW_S3_BUCKET | str \| None | None | |
-| MLFLOW_S3_ENDPOINT_URL | str \| None | None | |
-| MLFLOW_TRACKING_URI | str \| None | None | |
-| POSTGRES_USER | str \| None | None | |
-| POSTGRES_PASSWORD | str \| None | None | |
-| POSTGRES_HOST | str \| None | None | |
-| POSTGRES_PORT | str \| None | None | |
-| POSTGRES_DB | str \| None | None | |
-| LUXONISML_BUCKET | str \| None | None | |
-| LUXONISML_BASE_PATH | str | "~/luxonis_ml" | |
-| LUXONISML_TEAM_ID | str | "offline" | |
-| LUXONISML_TEAM_NAME | str | "offline" | |
-| LOG_LEVEL | Literal\["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"\] | "INFO" | |
+> \[!WARNING\]
+> This is not a recommended way due to possible leakage of secrets!
+> This section is intended for testing purposes only!
+> Use environment variables or `.env` files instead.
+
+| Key | Type | Default value |
+| -------------------------- | ---------------------------------------------------------- | ---------------- |
+| `AWS_ACCESS_KEY_ID` | `str \| None` | `None` |
+| `AWS_SECRET_ACCESS_KEY` | `str \| None` | `None` |
+| `AWS_S3_ENDPOINT_URL` | `str \| None` | `None` |
+| `MLFLOW_CLOUDFLARE_ID` | `str \| None` | `None` |
+| `MLFLOW_CLOUDFLARE_SECRET` | `str \| None` | `None` |
+| `MLFLOW_S3_BUCKET` | `str \| None` | `None` |
+| `MLFLOW_S3_ENDPOINT_URL` | `str \| None` | `None` |
+| `MLFLOW_TRACKING_URI` | `str \| None` | `None` |
+| `POSTGRES_USER` | `str \| None` | `None` |
+| `POSTGRES_PASSWORD` | `str \| None` | `None` |
+| `POSTGRES_HOST` | `str \| None` | `None` |
+| `POSTGRES_PORT` | `str \| None` | `None` |
+| `POSTGRES_DB` | `str \| None` | `None` |
+| `LUXONISML_BUCKET` | `str \| None` | `None` |
+| `LUXONISML_BASE_PATH` | `str` | `"~/luxonis_ml"` |
+| `LUXONISML_TEAM_ID` | `str` | `"offline"` |
+| `LOG_LEVEL` | `Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]` | `"INFO"` |
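+
+**Example:**
+
+A minimal sketch of setting these variables directly in the config for local testing only, assuming the section key matches the heading (`ENVIRON`); the values shown are placeholders:
+
+```yaml
+ENVIRON:
+  LOG_LEVEL: "DEBUG"
+  LUXONISML_BASE_PATH: "~/luxonis_ml"
+  LUXONISML_TEAM_ID: "offline"
+```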
diff --git a/configs/classification_heavy_model.yaml b/configs/classification_heavy_model.yaml
index 22b590e6..6ef1f443 100644
--- a/configs/classification_heavy_model.yaml
+++ b/configs/classification_heavy_model.yaml
@@ -14,9 +14,9 @@ loader:
trainer:
preprocessing:
train_image_size: [384, 512]
- keep_aspect_ratio: True
+ keep_aspect_ratio: true
normalize:
- active: True
+ active: true
batch_size: 8
epochs: 200
diff --git a/configs/classification_light_model.yaml b/configs/classification_light_model.yaml
index 32f7d96b..6eeba5fd 100644
--- a/configs/classification_light_model.yaml
+++ b/configs/classification_light_model.yaml
@@ -14,9 +14,9 @@ loader:
trainer:
preprocessing:
train_image_size: [384, 512]
- keep_aspect_ratio: True
+ keep_aspect_ratio: true
normalize:
- active: True
+ active: true
batch_size: 8
epochs: 200
diff --git a/configs/complex_model.yaml b/configs/complex_model.yaml
index 1ba11dd1..149530ad 100644
--- a/configs/complex_model.yaml
+++ b/configs/complex_model.yaml
@@ -15,7 +15,7 @@ model:
- RepPANNeck
losses:
- name: EfficientKeypointBboxLoss
+ - name: EfficientKeypointBBoxLoss
metrics:
- name: ObjectKeypointSimilarity
@@ -23,22 +23,22 @@ model:
- name: MeanAveragePrecisionKeypoints
visualizers:
- name: MultiVisualizer
- params:
- visualizers:
- - name: KeypointVisualizer
- params:
- nonvisible_color: blue
- - name: BBoxVisualizer
- params:
- colors:
- person: "#FF5055"
+ - name: MultiVisualizer
+ params:
+ visualizers:
+ - name: KeypointVisualizer
+ params:
+ nonvisible_color: blue
+ - name: BBoxVisualizer
+ params:
+ colors:
+ person: "#FF5055"
- name: SegmentationHead
inputs:
- RepPANNeck
losses:
- name: BCEWithLogitsLoss
+ - name: BCEWithLogitsLoss
metrics:
- name: F1Score
params:
@@ -47,9 +47,9 @@ model:
params:
task: binary
visualizers:
- name: SegmentationVisualizer
- params:
- colors: "#FF5055"
+ - name: SegmentationVisualizer
+ params:
+ colors: "#FF5055"
- name: EfficientBBoxHead
inputs:
@@ -58,18 +58,18 @@ model:
conf_thres: 0.75
iou_thres: 0.45
losses:
- name: AdaptiveDetectionLoss
+ - name: AdaptiveDetectionLoss
metrics:
- name: MeanAveragePrecision
+ - name: MeanAveragePrecision
visualizers:
- name: BBoxVisualizer
+ - name: BBoxVisualizer
tracker:
project_name: coco_test
save_directory: output
- is_tensorboard: True
- is_wandb: False
- is_mlflow: False
+ is_tensorboard: true
+ is_wandb: false
+ is_mlflow: false
loader:
train_view: train
@@ -86,23 +86,23 @@ trainer:
n_sanity_val_steps: 1
profiler: null
- verbose: True
+ verbose: true
batch_size: 8
accumulate_grad_batches: 1
epochs: &epochs 200
n_workers: 8
validation_interval: 10
n_log_images: 8
- skip_last_batch: True
- log_sub_losses: True
+ skip_last_batch: true
+ log_sub_losses: true
save_top_k: 3
preprocessing:
train_image_size: [&height 384, &width 384]
- keep_aspect_ratio: True
- train_rgb: True
+ keep_aspect_ratio: true
+ train_rgb: true
normalize:
- active: True
+ active: true
augmentations:
- name: Defocus
params:
@@ -131,6 +131,7 @@ trainer:
mode: min
verbose: true
- name: ExportOnTrainEnd
+ - name: ArchiveOnTrainEnd
- name: TestOnTrainEnd
optimizer:
@@ -138,7 +139,7 @@ trainer:
params:
lr: 0.02
momentum: 0.937
- nesterov: True
+ nesterov: true
weight_decay: 0.0005
scheduler:
diff --git a/configs/detection_heavy_model.yaml b/configs/detection_heavy_model.yaml
index f35c1ed3..294034c2 100644
--- a/configs/detection_heavy_model.yaml
+++ b/configs/detection_heavy_model.yaml
@@ -14,9 +14,9 @@ loader:
trainer:
preprocessing:
train_image_size: [384, 512]
- keep_aspect_ratio: True
+ keep_aspect_ratio: true
normalize:
- active: True
+ active: true
batch_size: 8
epochs: &epochs 200
diff --git a/configs/detection_light_model.yaml b/configs/detection_light_model.yaml
index 1f982d92..aca202bd 100644
--- a/configs/detection_light_model.yaml
+++ b/configs/detection_light_model.yaml
@@ -14,9 +14,9 @@ loader:
trainer:
preprocessing:
train_image_size: [384, 512]
- keep_aspect_ratio: True
+ keep_aspect_ratio: true
normalize:
- active: True
+ active: true
batch_size: 8
epochs: &epochs 200
diff --git a/configs/example_export.yaml b/configs/example_export.yaml
index 78f1c650..ff9b1f3d 100644
--- a/configs/example_export.yaml
+++ b/configs/example_export.yaml
@@ -15,9 +15,9 @@ loader:
trainer:
preprocessing:
train_image_size: [384, 512]
- keep_aspect_ratio: True
+ keep_aspect_ratio: true
normalize:
- active: True
+ active: true
batch_size: 8
epochs: &epochs 200
@@ -46,5 +46,5 @@ exporter:
onnx:
opset_version: 11
blobconverter:
- active: True
+ active: true
shaves: 8
diff --git a/configs/example_tuning.yaml b/configs/example_tuning.yaml
index 9e63c877..8e7a6215 100755
--- a/configs/example_tuning.yaml
+++ b/configs/example_tuning.yaml
@@ -2,9 +2,9 @@
model:
- name: segmentation_light
+ name: detection_light
predefined_model:
- name: SegmentationModel
+ name: DetectionModel
params:
variant: light
@@ -15,9 +15,9 @@ loader:
trainer:
preprocessing:
train_image_size: [384, 512]
- keep_aspect_ratio: True
+ keep_aspect_ratio: true
normalize:
- active: True
+ active: true
augmentations:
- name: Defocus
params:
@@ -40,7 +40,7 @@ trainer:
tuner:
- study_name: seg_study
+ study_name: det_study
n_trials: 10
storage:
storage_type: local
diff --git a/configs/keypoint_bbox_heavy_model.yaml b/configs/keypoint_bbox_heavy_model.yaml
index c6b22f35..10527921 100644
--- a/configs/keypoint_bbox_heavy_model.yaml
+++ b/configs/keypoint_bbox_heavy_model.yaml
@@ -14,9 +14,9 @@ loader:
trainer:
preprocessing:
train_image_size: [384, 512]
- keep_aspect_ratio: True
+ keep_aspect_ratio: true
normalize:
- active: True
+ active: true
batch_size: 8
epochs: &epochs 200
diff --git a/configs/keypoint_bbox_light_model.yaml b/configs/keypoint_bbox_light_model.yaml
index a095a551..57042b04 100644
--- a/configs/keypoint_bbox_light_model.yaml
+++ b/configs/keypoint_bbox_light_model.yaml
@@ -14,9 +14,9 @@ loader:
trainer:
preprocessing:
train_image_size: [384, 512]
- keep_aspect_ratio: True
+ keep_aspect_ratio: true
normalize:
- active: True
+ active: true
batch_size: 8
epochs: &epochs 200
diff --git a/configs/segmentation_heavy_model.yaml b/configs/segmentation_heavy_model.yaml
index e9bc16d6..8da7eba8 100644
--- a/configs/segmentation_heavy_model.yaml
+++ b/configs/segmentation_heavy_model.yaml
@@ -14,9 +14,9 @@ loader:
trainer:
preprocessing:
train_image_size: [384, 512]
- keep_aspect_ratio: True
+ keep_aspect_ratio: true
normalize:
- active: True
+ active: true
batch_size: 8
epochs: &epochs 200
diff --git a/configs/segmentation_light_model.yaml b/configs/segmentation_light_model.yaml
index c03703f4..40d38595 100644
--- a/configs/segmentation_light_model.yaml
+++ b/configs/segmentation_light_model.yaml
@@ -14,9 +14,9 @@ loader:
trainer:
preprocessing:
train_image_size: [384, 512]
- keep_aspect_ratio: True
+ keep_aspect_ratio: true
normalize:
- active: True
+ active: true
batch_size: 8
epochs: &epochs 200
diff --git a/luxonis_train/__main__.py b/luxonis_train/__main__.py
index 798a9baa..c0aae2dc 100644
--- a/luxonis_train/__main__.py
+++ b/luxonis_train/__main__.py
@@ -41,6 +41,15 @@ class _ViewType(str, Enum):
),
]
+WeightsType = Annotated[
+ Path | None,
+ typer.Option(
+ help="Path to the model weights.",
+ show_default=False,
+ metavar="FILE",
+ ),
+]
+
ViewType = Annotated[
_ViewType, typer.Option(help="Which dataset view to use.")
]
@@ -77,12 +86,13 @@ def train(
def test(
config: ConfigType = None,
view: ViewType = _ViewType.VAL,
+ weights: WeightsType = None,
opts: OptsType = None,
):
"""Evaluate model."""
from luxonis_train.core import LuxonisModel
- LuxonisModel(config, opts).test(view=view.value)
+ LuxonisModel(config, opts).test(view=view.value, weights=weights)
@app.command()
@@ -94,11 +104,21 @@ def tune(config: ConfigType = None, opts: OptsType = None):
@app.command()
-def export(config: ConfigType = None, opts: OptsType = None):
+def export(
+ config: ConfigType = None,
+ save_path: Annotated[
+ Path | None,
+ typer.Option(help="Path where to save the exported model."),
+ ] = None,
+ weights: WeightsType = None,
+ opts: OptsType = None,
+):
"""Export model."""
from luxonis_train.core import LuxonisModel
- LuxonisModel(config, opts).export()
+ LuxonisModel(config, opts).export(
+ onnx_save_path=save_path, weights=weights
+ )
@app.command()
@@ -107,13 +127,17 @@ def infer(
view: ViewType = _ViewType.VAL,
save_dir: SaveDirType = None,
source_path: SourcePathType = None,
+ weights: WeightsType = None,
opts: OptsType = None,
):
"""Run inference."""
from luxonis_train.core import LuxonisModel
LuxonisModel(config, opts).infer(
- view=view.value, save_dir=save_dir, source_path=source_path
+ view=view.value,
+ save_dir=save_dir,
+ source_path=source_path,
+ weights=weights,
)
@@ -138,7 +162,8 @@ def inspect(
"-s",
help=(
"Multiplier for the image size. "
- "By default the images are shown in their original size."
+ "By default the images are shown in their original size. "
+ "Use this option to scale them."
),
show_default=False,
),
@@ -223,19 +248,20 @@ def inspect(
@app.command()
def archive(
+ config: ConfigType = None,
executable: Annotated[
str | None,
typer.Option(
help="Path to the model file.", show_default=False, metavar="FILE"
),
] = None,
- config: ConfigType = None,
+ weights: WeightsType = None,
opts: OptsType = None,
):
"""Generate NN archive."""
from luxonis_train.core import LuxonisModel
- LuxonisModel(str(config), opts).archive(executable)
+ LuxonisModel(str(config), opts).archive(path=executable, weights=weights)
def version_callback(value: bool):
diff --git a/luxonis_train/attached_modules/losses/README.md b/luxonis_train/attached_modules/losses/README.md
index a8a982ba..724174c7 100644
--- a/luxonis_train/attached_modules/losses/README.md
+++ b/luxonis_train/attached_modules/losses/README.md
@@ -4,97 +4,97 @@ List of all the available loss functions.
## Table Of Contents
-- [CrossEntropyLoss](#crossentropyloss)
-- [BCEWithLogitsLoss](#bcewithlogitsloss)
-- [SmoothBCEWithLogitsLoss](#smoothbcewithlogitsloss)
-- [SigmoidFocalLoss](#sigmoidfocalloss)
-- [SoftmaxFocalLoss](#softmaxfocalloss)
-- [AdaptiveDetectionLoss](#adaptivedetectionloss)
-- [EfficientKeypointBBoxLoss](#efficientkeypointbboxloss)
+- [`CrossEntropyLoss`](#crossentropyloss)
+- [`BCEWithLogitsLoss`](#bcewithlogitsloss)
+- [`SmoothBCEWithLogitsLoss`](#smoothbcewithlogitsloss)
+- [`SigmoidFocalLoss`](#sigmoidfocalloss)
+- [`SoftmaxFocalLoss`](#softmaxfocalloss)
+- [`AdaptiveDetectionLoss`](#adaptivedetectionloss)
+- [`EfficientKeypointBBoxLoss`](#efficientkeypointbboxloss)
-## CrossEntropyLoss
+## `CrossEntropyLoss`
Adapted from [here](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html).
-**Params**
+**Parameters:**
-| Key | Type | Default value | Description |
-| --------------- | -------------------------------- | ------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| weight | list\[float\] \| None | None | A manual rescaling weight given to each class. If given, it has to be a list of the same length as there are classes. |
-| reduction | Literal\["none", "mean", "sum"\] | "mean" | Specifies the reduction to apply to the output. |
-| label_smoothing | float\[0.0, 1.0\] | 0.0 | Specifies the amount of smoothing when computing the loss, where 0.0 means no smoothing. The targets become a mixture of the original ground truth and a uniform distribution as described in [Rethinking the Inception Architecture for Computer Vision](https://arxiv.org/abs/1512.00567). |
+| Key | Type | Default value | Description |
+| ----------------- | -------------------------------- | ------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `weight` | `list[float] \| None` | `None` | A manual rescaling weight given to each class. If given, it has to be a list of the same length as there are classes |
+| `reduction` | `Literal["none", "mean", "sum"]` | `"mean"` | Specifies the reduction to apply to the output |
+| `label_smoothing` | `float` $\\in \[0.0, 1.0\]$ | `0.0` | Specifies the amount of smoothing when computing the loss, where 0.0 means no smoothing. The targets become a mixture of the original ground truth and a uniform distribution as described in [Rethinking the Inception Architecture for Computer Vision](https://arxiv.org/abs/1512.00567) |
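+
+A minimal configuration sketch showing how such parameters can be passed to a loss attached to a head (the structure follows the example configs; the values are illustrative):
+
+```yaml
+losses:
+  - name: CrossEntropyLoss
+    params:
+      reduction: "mean"
+      label_smoothing: 0.1
+```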
-## BCEWithLogitsLoss
+## `BCEWithLogitsLoss`
Adapted from [here](https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html).
-**Params**
+**Parameters:**
-| Key | Type | Default value | Description |
-| ---------- | -------------------------------- | ------------- | ------------------------------------------------------------------------------------------------------------------ |
-| weight | list\[float\] \| None | None | A manual rescaling weight given to each class. If given, has to be a list of the same length as there are classes. |
-| reduction | Literal\["none", "mean", "sum"\] | "mean" | Specifies the reduction to apply to the output. |
-| pos_weight | Tensor \| None | None | A weight of positive examples to be broadcasted with target. |
+| Key | Type | Default value | Description |
+| ------------ | -------------------------------- | ------------- | ----------------------------------------------------------------------------------------------------------------- |
+| `weight` | `list[float] \| None` | `None` | A manual rescaling weight given to each class. If given, has to be a list of the same length as there are classes |
+| `reduction` | `Literal["none", "mean", "sum"]` | `"mean"` | Specifies the reduction to apply to the output |
+| `pos_weight` | `Tensor \| None` | `None` | A weight of positive examples to be broadcasted with target |
-## SmoothBCEWithLogitsLoss
+## `SmoothBCEWithLogitsLoss`
-**Params**
+**Parameters:**
-| Key | Type | Default value | Description |
-| --------------- | -------------------------------- | ------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| weight | list\[float\] \| None | None | A manual rescaling weight given to each class. If given, has to be a list of the same length as there are classes. |
-| reduction | Literal\["none", "mean", "sum"\] | "mean" | Specifies the reduction to apply to the output. |
-| label_smoothing | float\[0.0, 1.0\] | 0.0 | Specifies the amount of smoothing when computing the loss, where 0.0 means no smoothing. The targets become a mixture of the original ground truth and a uniform distribution as described in [Rethinking the Inception Architecture for Computer Vision](https://arxiv.org/abs/1512.00567). |
-| bce_pow | float | 1.0 | Weight for the positive samples. |
+| Key | Type | Default value | Description |
+| ----------------- | -------------------------------- | ------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `weight` | `list[float] \| None` | `None` | A manual rescaling weight given to each class. If given, has to be a list of the same length as there are classes |
+| `reduction` | `Literal["none", "mean", "sum"]` | `"mean"` | Specifies the reduction to apply to the output |
+| `label_smoothing` | `float` $\\in \[0.0, 1.0\]$ | `0.0` | Specifies the amount of smoothing when computing the loss, where 0.0 means no smoothing. The targets become a mixture of the original ground truth and a uniform distribution as described in [Rethinking the Inception Architecture for Computer Vision](https://arxiv.org/abs/1512.00567) |
+| `bce_pow` | `float` | `1.0` | Weight for the positive samples |
-## SigmoidFocalLoss
+## `SigmoidFocalLoss`
Adapted from [here](https://pytorch.org/vision/stable/generated/torchvision.ops.sigmoid_focal_loss.html#torchvision.ops.sigmoid_focal_loss).
-**Params**
+**Parameters:**
-| Key | Type | Default value | Description |
-| --------- | -------------------------------- | ------------- | ------------------------------------------------------------------------------------------ |
-| alpha | float | 0.25 | Weighting factor in range (0,1) to balance positive vs negative examples or -1 for ignore. |
-| gamma | float | 2.0 | Exponent of the modulating factor $(1 - p_t)$ to balance easy vs hard examples |
-| reduction | Literal\["none", "mean", "sum"\] | "mean" | Specifies the reduction to apply to the output. |
+| Key | Type | Default value | Description |
+| ----------- | -------------------------------- | ------------- | ------------------------------------------------------------------------------------------- |
+| `alpha` | `float` | `0.25` | Weighting factor in range $(0,1)$ to balance positive vs negative examples or -1 for ignore |
+| `gamma` | `float` | `2.0` | Exponent of the modulating factor $(1 - p_t)$ to balance easy vs hard examples |
+| `reduction` | `Literal["none", "mean", "sum"]` | `"mean"` | Specifies the reduction to apply to the output |
-## SoftmaxFocalLoss
+## `SoftmaxFocalLoss`
-**Params**
+**Parameters:**
-| Key | Type | Default value | Description |
-| --------- | -------------------------------- | ------------- | ----------------------------------------------------------------------------- |
-| alpha | float \| list | 0.25 | Either a float for all channels or list of alphas for each channel. |
-| gamma | float | 2.0 | Exponent of the modulating factor (1 - p_t) to balance easy vs hard examples. |
-| reduction | Literal\["none", "mean", "sum"\] | "mean" | Specifies the reduction to apply to the output. |
+| Key | Type | Default value | Description |
+| ----------- | -------------------------------- | ------------- | ------------------------------------------------------------------------------ |
+| `alpha` | `float \| list` | `0.25` | Either a float for all channels or list of alphas for each channel |
+| `gamma` | `float` | `2.0` | Exponent of the modulating factor $(1 - p_t)$ to balance easy vs hard examples |
+| `reduction` | `Literal["none", "mean", "sum"]` | `"mean"` | Specifies the reduction to apply to the output |
-## AdaptiveDetectionLoss
+## `AdaptiveDetectionLoss`
Adapted from [here](https://arxiv.org/pdf/2209.02976.pdf).
-**Params**
+**Parameters:**
-| Key | Type | Default value | Description |
-| ----------------- | ------------------------------------------------- | ------------- | ----------------------------------------------------------------------------------- |
-| n_warmup_epochs | int | 4 | Number of epochs where ATSS assigner is used, after that we switch to TAL assigner. |
-| iou_type | Literal\["none", "giou", "diou", "ciou", "siou"\] | "giou" | IoU type used for bbox regression loss. |
-| class_loss_weight | float | 1.0 | Weight used for the classification part of the loss. |
-| iou_loss_weight | float | 2.5 | Weight used for the IoU part of the loss. |
+| Key | Type | Default value | Description |
+| ------------------- | ------------------------------------------------- | ------------- | -------------------------------------------------------------------------------------- |
+| `n_warmup_epochs` | `int` | `4` | Number of epochs where `ATSS` assigner is used, after that we switch to `TAL` assigner |
+| `iou_type` | `Literal["none", "giou", "diou", "ciou", "siou"]` | `"giou"` | `IoU` type used for bounding box regression loss |
+| `class_loss_weight` | `float` | `1.0` | Weight used for the classification part of the loss |
+| `iou_loss_weight` | `float` | `2.5` | Weight used for the `IoU` part of the loss |
-## EfficientKeypointBBoxLoss
+## `EfficientKeypointBBoxLoss`
Adapted from [YOLO-Pose: Enhancing YOLO for Multi Person Pose Estimation Using Object
Keypoint Similarity Loss](https://arxiv.org/ftp/arxiv/papers/2204/2204.06806.pdf).
-| Key | Type | Default value | Description |
-| --------------------- | ------------------------------------------------- | ------------- | --------------------------------------------------------------------------------------------------- |
-| n_warmup_epochs | int | 4 | Number of epochs where ATSS assigner is used, after that we switch to TAL assigner. |
-| iou_type | Literal\["none", "giou", "diou", "ciou", "siou"\] | "giou" | IoU type used for bbox regression sub-loss |
-| reduction | Literal\["mean", "sum"\] | "mean" | Specifies the reduction to apply to the output. |
-| class_loss_weight | float | 1.0 | Weight used for the classification sub-loss. |
-| iou_loss_weight | float | 2.5 | Weight used for the IoU sub-loss. |
-| regr_kpts_loss_weight | float | 1.5 | Weight used for the OKS sub-loss. |
-| vis_kpts_loss_weight | float | 1.0 | Weight used for the keypoint visibility sub-loss. |
-| sigmas | list\[float\] \\ None | None | Sigmas used in KeypointLoss for OKS metric. If None then use COCO ones if possible or default ones. |
-| area_factor | float \| None | None | Factor by which we multiply bbox area which is used in KeypointLoss. If None then use default one. |
+| Key | Type | Default value | Description |
+| ----------------------- | ------------------------------------------------- | ------------- | ------------------------------------------------------------------------------------------------------------- |
+| `n_warmup_epochs` | `int` | `4` | Number of epochs where `ATSS` assigner is used, after that we switch to `TAL` assigner |
+| `iou_type` | `Literal["none", "giou", "diou", "ciou", "siou"]` | `"giou"` | `IoU` type used for bounding box regression sub-loss |
+| `reduction` | `Literal["mean", "sum"]` | `"mean"` | Specifies the reduction to apply to the output |
+| `class_loss_weight` | `float` | `1.0` | Weight used for the classification sub-loss |
+| `iou_loss_weight` | `float` | `2.5` | Weight used for the `IoU` sub-loss |
+| `regr_kpts_loss_weight` | `float` | `1.5` | Weight used for the `OKS` sub-loss |
+| `vis_kpts_loss_weight` | `float` | `1.0` | Weight used for the keypoint visibility sub-loss |
+| `sigmas`                | `list[float] \| None`                             | `None`        | Sigmas used in `KeypointLoss` for the `OKS` metric. If `None`, COCO sigmas are used if possible, otherwise default ones |
+| `area_factor`           | `float \| None`                                   | `None`        | Factor by which we multiply the bounding box area used in `KeypointLoss`. If `None`, the default one is used   |
diff --git a/luxonis_train/attached_modules/metrics/README.md b/luxonis_train/attached_modules/metrics/README.md
index 17735540..b61f4843 100644
--- a/luxonis_train/attached_modules/metrics/README.md
+++ b/luxonis_train/attached_modules/metrics/README.md
@@ -23,6 +23,14 @@ Metrics from the [`torchmetrics`](https://lightning.ai/docs/torchmetrics/stable/
For more information, see [object-keypoint-similarity](https://learnopencv.com/object-keypoint-similarity/).
+**Params**
+
+| Key | Type | Default value | Description |
+| ------------------ | --------------------- | ------------- | --------------------------------------------------------------------- |
+| `sigmas` | `list[float] \| None` | `None` | List of sigmas for each keypoint. If `None`, the COCO sigmas are used |
+| `area_factor` | `float` | `0.53` | Factor by which to multiply the bounding box area |
+| `use_cocoeval_oks` | `bool` | `True` | Whether to use the same OKS formula as in COCO evaluation |
+
## MeanAveragePrecision
Compute the `Mean-Average-Precision (mAP) and Mean-Average-Recall (mAR)` for object detection predictions.
@@ -43,4 +51,13 @@ boxes.
Similar to [MeanAveragePrecision](#meanaverageprecision), but uses [OKS](#objectkeypointsimilarity) as `IoU` measure.
For a deeper understanding of how OKS works, please refer to the detailed explanation provided [here](https://learnopencv.com/object-keypoint-similarity/).
-Evaluation leverages COCO evaluation framework (COCOeval) to assess mAP performance.
+Evaluation leverages the COCO evaluation framework (`COCOeval`) to assess `mAP` performance.
+
+**Params**
+
+| Key | Type | Default value | Description |
+| ------------- | ----------------------------------- | ------------- | --------------------------------------------------------------------- |
+| `sigmas` | `list[float] \| None` | `None` | List of sigmas for each keypoint. If `None`, the COCO sigmas are used |
+| `area_factor` | `float` | `0.53` | Factor by which to multiply the bounding box area |
+| `max_dets` | `int` | `20` | Maximum number of detections per image |
+| `box_format`  | `Literal["xyxy", "xywh", "cxcywh"]` | `"xyxy"`      | Format of the bounding boxes                                           |
diff --git a/luxonis_train/attached_modules/visualizers/README.md b/luxonis_train/attached_modules/visualizers/README.md
index 8bedaed9..1fca42e2 100644
--- a/luxonis_train/attached_modules/visualizers/README.md
+++ b/luxonis_train/attached_modules/visualizers/README.md
@@ -1,87 +1,88 @@
# Visualizers
+Visualizers render the outputs of a node. They are specified in the `visualizers` field of the node configuration.
+
## Table Of Contents
-- [BBoxVisualizer](#bboxvisualizer)
-- [ClassificationVisualizer](#classificationvisualizer)
-- [KeypointVisualizer](#keypointvisualizer)
-- [SegmentationVisualizer](#segmentationvisualizer)
-- [MultiVisualizer](#multivisualizer)
+- [`BBoxVisualizer`](#bboxvisualizer)
+- [`ClassificationVisualizer`](#classificationvisualizer)
+- [`KeypointVisualizer`](#keypointvisualizer)
+- [`SegmentationVisualizer`](#segmentationvisualizer)
+- [`MultiVisualizer`](#multivisualizer)
-## BBoxVisualizer
+## `BBoxVisualizer`
Visualizer for bounding boxes.
-**Params**
+**Parameters:**
-| Key | Type | Default value | Description |
-| --------- | ------------------------------------------------------------------------------------------- | ------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| labels | dict\[int, str\] \| list\[str\] \| None | None | Either a dictionary mapping class indices to names, or a list of names. If list is provided, the label mapping is done by index. By default, no labels are drawn. |
-| colors | dict\[int, tuple\[int, int, int\] \| str\] \| list\[tuple\[int, int, int\] \| str\] \| None | None | Colors to use for the boundig boxes. Either a dictionary mapping class names to colors, or a list of colors. |
-| fill | bool | False | Whether or not to fill the bounding boxes. |
-| width | int | 1 | The width of the bounding box lines. |
-| font | str \| None | None | A filename containing a TrueType font. |
-| font_size | int \| None | None | Font size used for the labels. |
+| Key | Type | Default value | Description |
+| ----------- | ------------------------------------------------------------------------------------- | ------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `labels` | `dict[int, str] \| list[str] \| None` | `None` | Either a dictionary mapping class indices to names, or a list of names. If list is provided, the label mapping is done by index. By default, no labels are drawn |
+| `colors` | `dict[int, tuple[int, int, int] \| str] \| list[tuple[int, int, int] \| str] \| None` | `None` | Colors to use for the bounding boxes. Either a dictionary mapping class names to colors, or a list of colors. Color can be either a tuple of RGB values or a hex string |
+| `fill` | `bool` | `False` | Whether to fill the bounding boxes |
+| `width` | `int` | `1` | The width of the bounding box lines |
+| `font` | `str \| None` | `None` | A filename containing a `TrueType` font |
+| `font_size` | `int \| None` | `None` | Font size used for the labels |
-**Example**
+**Example:**
-![bbox_viz_example](https://github.com/luxonis/luxonis-train/blob/main/media/example_viz/bbox.png)
+![bounding_box_viz_example](https://github.com/luxonis/luxonis-train/blob/main/media/example_viz/bbox.png)
-## KeypointVisualizer
+## `KeypointVisualizer`
-**Params**
+**Parameters:**
-| Key | Type | Default value | Description |
-| -------------------- | -------------------------------------- | ------------- | -------------------------------------------------------------------------------------------------------------------------------- |
-| visibility_threshold | float | 0.5 | Threshold for visibility of keypoints. If the visibility of a keypoint is below this threshold, it is considered as not visible. |
-| connectivity | list\[tuple\[int, int\]\] \| None | None | List of tuples of keypoint indices that define the connections in the skeleton. |
-| visible_color | str \| tuple\[int, int, int\] | "red" | Color of visible keypoints. |
-| nonvisible_color | str \| tuple\[int, int, int \] \| None | None | Color of nonvisible keypoints. If None, nonvisible keypoints are not drawn. |
+| Key | Type | Default value | Description |
+| ---------------------- | -------------------------------------- | ------------- | ------------------------------------------------------------------------------------------------------------------------------- |
+| `visibility_threshold` | `float` | `0.5` | Threshold for visibility of keypoints. If the visibility of a keypoint is below this threshold, it is considered as not visible |
+| `connectivity` | `list[tuple[int, int]] \| None` | `None` | List of tuples of keypoint indices that define the connections in the skeleton |
+| `visible_color` | `str \| tuple[int, int, int]` | `"red"` | Color of visible keypoints |
+| `nonvisible_color`     | `str \| tuple[int, int, int] \| None`  | `None`        | Color of non-visible keypoints. If `None`, non-visible keypoints are not drawn                                                    |
-**Example**
+**Example:**
-![kpt_viz_example](https://github.com/luxonis/luxonis-train/blob/main/media/example_viz/kpts.png)
+![keypoints_viz_example](https://github.com/luxonis/luxonis-train/blob/main/media/example_viz/kpts.png)
-## SegmentationVisualizer
+## `SegmentationVisualizer`
-**Params**
+**Parameters:**
-| Key | Type | Default value | Description |
-| ----- | ----------------------------- | ------------- | -------------------------------------- |
-| color | str \| tuple\[int, int, int\] | #5050FF | Color of the segmentation masks. |
-| alpha | float | 0.6 | Alpha value of the segmentation masks. |
+| Key | Type | Default value | Description |
+| ------- | ----------------------------- | ------------- | ------------------------------------- |
+| `color` | `str \| tuple[int, int, int]` | `"#5050FF"` | Color of the segmentation masks |
+| `alpha` | `float` | `0.6` | Alpha value of the segmentation masks |
-**Example**
+**Example:**
-![seg_viz_example](https://github.com/luxonis/luxonis-train/blob/main/media/example_viz/segmentation.png)
+![segmentation_viz_example](https://github.com/luxonis/luxonis-train/blob/main/media/example_viz/segmentation.png)
-## ClassificationVisualizer
+## `ClassificationVisualizer`
-**Params**
+**Parameters:**
-| Key | Type | Default value | Description |
-| ------------ | ---------------------- | ------------- | -------------------------------------------------------------------------- |
-| include_plot | bool | True | Whether to include a plot of the class probabilities in the visualization. |
-| color | tuple\[int, int, int\] | (255, 0, 0) | Color of the text. |
-| font_scale | float | 1.0 | Scale of the font. |
-| thickness | int | 1 | Line thickness of the font. |
+| Key | Type | Default value | Description |
+| -------------- | ---------------------- | ------------- | ------------------------------------------------------------------------- |
+| `include_plot` | `bool` | `True` | Whether to include a plot of the class probabilities in the visualization |
+| `color` | `tuple[int, int, int]` | `(255, 0, 0)` | Color of the text |
+| `font_scale` | `float` | `1.0` | Scale of the font |
+| `thickness` | `int` | `1` | Line thickness of the font |
-**Example**
+**Example:**
![class_viz_example](https://github.com/luxonis/luxonis-train/blob/main/media/example_viz/class.png)
-## MultiVisualizer
+## `MultiVisualizer`
Special type of meta-visualizer that combines several visualizers into one. The combined visualizers share canvas.
-**Params**
+**Parameters:**
-| Key | Type | Default value | Description |
-| ----------- | ------------ | ------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
-| visualizers | list\[dict\] | \[ \] | List of visualizers to combine. Each item in the list is a dictionary with the following keys: - name (str): Name of the visualizer. Must be a key in the VISUALIZERS registry. - params (dict): Parameters to pass to the visualizer. |
+| Key | Type | Default value | Description |
+| ------------- | ------------ | ------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
+| `visualizers` | `list[dict]` | `[]` | List of visualizers to combine. Each item in the list is a dictionary with the following keys: - `"name"` (`str`): Name of the visualizer. Must be a key in the `VISUALIZERS` registry. - `"params"` (`dict`): Parameters to pass to the visualizer |
-**Example**
+**Example:**
-Example of combining [KeypointVisualizer](#keypointvisualizer) and [BBoxVisualizer](#bboxvisualizer).
+Example of combining [`KeypointVisualizer`](#keypointvisualizer) and [`BBoxVisualizer`](#bboxvisualizer).
![multi_viz_example](https://github.com/luxonis/luxonis-train/blob/main/media/example_viz/multi.png)
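+
+A configuration sketch of such a combination, mirroring the complex model example config:
+
+```yaml
+visualizers:
+  - name: MultiVisualizer
+    params:
+      visualizers:
+        - name: KeypointVisualizer
+          params:
+            nonvisible_color: blue
+        - name: BBoxVisualizer
+```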
diff --git a/luxonis_train/callbacks/README.md b/luxonis_train/callbacks/README.md
index dc015ccd..64fbdf4f 100644
--- a/luxonis_train/callbacks/README.md
+++ b/luxonis_train/callbacks/README.md
@@ -4,54 +4,64 @@ List of all supported callbacks.
## Table Of Contents
-- [PytorchLightning Callbacks](#pytorchlightning-callbacks)
-- [ExportOnTrainEnd](#exportontrainend)
-- [LuxonisProgressBar](#luxonisprogressbar)
-- [MetadataLogger](#metadatalogger)
-- [TestOnTrainEnd](#testontrainend)
-- [UploadCheckpoint](#uploadcheckpoint)
+- [`PytorchLightning` Callbacks](#pytorchlightning-callbacks)
+- [`ExportOnTrainEnd`](#exportontrainend)
+- [`ArchiveOnTrainEnd`](#archiveontrainend)
+- [`MetadataLogger`](#metadatalogger)
+- [`TestOnTrainEnd`](#testontrainend)
+- [`UploadCheckpoint`](#uploadcheckpoint)
-## PytorchLightning Callbacks
+## `PytorchLightning` Callbacks
List of supported callbacks from `lightning.pytorch`.
-- [GPUStatsMonitor](https://pytorch-lightning.readthedocs.io/en/1.5.10/api/pytorch_lightning.callbacks.gpu_stats_monitor.html)
-- [DeviceStatsMonitor](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.DeviceStatsMonitor.html#lightning.pytorch.callbacks.DeviceStatsMonitor)
-- [EarlyStopping](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html#lightning.pytorch.callbacks.EarlyStopping)
-- [LearningRateMonitor](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.LearningRateMonitor.html#lightning.pytorch.callbacks.LearningRateMonitor)
-- [ModelCheckpoint](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html#lightning.pytorch.callbacks.ModelCheckpoint)
-- [RichModelSummary](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html#lightning.pytorch.callbacks.RichModelSummary)
+- [`GPUStatsMonitor`](https://pytorch-lightning.readthedocs.io/en/1.5.10/api/pytorch_lightning.callbacks.gpu_stats_monitor.html)
+- [`DeviceStatsMonitor`](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.DeviceStatsMonitor.html#lightning.pytorch.callbacks.DeviceStatsMonitor)
+- [`EarlyStopping`](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.EarlyStopping.html#lightning.pytorch.callbacks.EarlyStopping)
+- [`LearningRateMonitor`](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.LearningRateMonitor.html#lightning.pytorch.callbacks.LearningRateMonitor)
+- [`ModelCheckpoint`](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html#lightning.pytorch.callbacks.ModelCheckpoint)
+- [`RichModelSummary`](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichModelSummary.html#lightning.pytorch.callbacks.RichModelSummary)
+- [`GradientAccumulationScheduler`](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.GradientAccumulationScheduler.html#lightning.pytorch.callbacks.GradientAccumulationScheduler)
+- [`StochasticWeightAveraging`](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.StochasticWeightAveraging.html#lightning.pytorch.callbacks.StochasticWeightAveraging)
+- [`Timer`](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.Timer.html#lightning.pytorch.callbacks.Timer)
+- [`ModelPruning`](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelPruning.html#lightning.pytorch.callbacks.ModelPruning)
-## ExportOnTrainEnd
+## `ExportOnTrainEnd`
Performs export on train end with best weights.
-**Params**
+**Parameters:**
-| Key | Type | Default value | Description |
-| -------------------- | --------------------------- | ------------- | ----------------------------------------------------------------------------------------------------------- |
-| preferred_checkpoint | Literal\["metric", "loss"\] | metric | Which checkpoint should we use. If preferred is not available then try to use the other one if its present. |
+| Key | Type | Default value | Description |
+| ---------------------- | --------------------------- | ------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `preferred_checkpoint` | `Literal["metric", "loss"]` | `"metric"` | Which checkpoint should the callback use. If the preferred checkpoint is not available, the other option is used. If none is available, the callback is skipped |
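+
+A sketch of enabling this callback with a non-default checkpoint preference (the values are illustrative):
+
+```yaml
+callbacks:
+  - name: "ExportOnTrainEnd"
+    params:
+      preferred_checkpoint: "loss"
+```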
-## LuxonisProgressBar
+## `ArchiveOnTrainEnd`
-Custom rich text progress bar based on RichProgressBar from Pytorch Lightning.
+Callback to create an `NN Archive` at the end of the training.
-## MetadataLogger
+**Parameters:**
+
+| Key | Type | Default value | Description |
+| ---------------------- | --------------------------- | ------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `preferred_checkpoint` | `Literal["metric", "loss"]` | `"metric"` | Which checkpoint should the callback use. If the preferred checkpoint is not available, the other option is used. If none is available, the callback is skipped |
+
+## `MetadataLogger`
Callback that logs training metadata.
Metadata include all defined hyperparameters together with git hashes of `luxonis-ml` and `luxonis-train` packages. Also stores this information locally.
-**Params**
+**Parameters:**
-| Key | Type | Default value | Description |
-| ----------- | ----------- | ------------- | ----------------------------------------------------------------------------------------------------------------------- |
-| hyperparams | list\[str\] | \[\] | List of hyperparameters to log. The hyperparameters are provided as config keys in dot notation. E.g. "trainer.epochs". |
+| Key | Type | Default value | Description |
+| ------------- | ----------- | ------------- | -------------------------------------------------------------------------------------------------------------------------- |
+| `hyperparams` | `list[str]` | `[]` | List of hyperparameters to log. The hyperparameters are provided as config keys in dot notation. _E.g._ `"trainer.epochs"` |
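+
+For instance, to log a couple of hyperparameters referenced in dot notation (mirroring the callback example in the configuration documentation):
+
+```yaml
+callbacks:
+  - name: "MetadataLogger"
+    params:
+      hyperparams: ["trainer.epochs", "trainer.batch_size"]
+```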
-## TestOnTrainEnd
+## `TestOnTrainEnd`
Callback to perform a test run at the end of the training.
-## UploadCheckpoint
+## `UploadCheckpoint`
-Callback that uploads currently best checkpoint (based on validation loss) to the tracker location - where all other logs are stored.
+Callback that uploads the current best checkpoint (based on validation loss) to the tracker location, where all other logs are stored.
diff --git a/luxonis_train/config/predefined_models/README.md b/luxonis_train/config/predefined_models/README.md
index 3733534d..27976b3a 100644
--- a/luxonis_train/config/predefined_models/README.md
+++ b/luxonis_train/config/predefined_models/README.md
@@ -1,150 +1,150 @@
# Predefined models
-In addition to definig the model by hand, we offer a list of simple predefined
+In addition to defining the model by hand, we offer a list of simple predefined
models which can be used instead.
## Table Of Contents
-- [SegmentationModel](#segmentationmodel)
-- [DetectionModel](#detectionmodel)
-- [KeypointDetectionModel](#keypointdetectionmodel)
-- [ClassificationModel](#classificationmodel)
-
-**Params**
-
-| Key | Type | Default value | Description |
-| ------------------- | ---------------- | ------------- | --------------------------------------------------------------------- |
-| name | str | | Name of the predefined architecture. See below the available options. |
-| params | dict\[str, Any\] | {} | Additional parameters of the predefined model. |
-| include_nodes | bool | True | Whether to include nodes of the model. |
-| include_losses | bool | True | Whether to include loss functions. |
-| include_metrics | bool | True | Whether to include metrics. |
-| include_visualizers | bool | True | Whether to include visualizers. |
-
-## SegmentationModel
-
-The `SegmentationModel` allows for both "light" and "heavy" variants, where the "heavy" variant is more accurate, and the "light" variant is faster.
-
-See an example configuration file using this predefined model [here](../../../configs/segmentation_light_model.yaml) for the "light" variant, and [here](../../../configs/segmentation_heavy_model.yaml) for the "heavy" variant.
-
-**Components**
-
-| Name | Alias | Function |
-| --------------------------------------------------------------------------------------------- | -------------------------- | -------------------------------------------------------------------------------------------- |
-| [DDRNet](../../nodes/README.md#ddrnet) | segmentation_backbone | Backbone of the model. Available variants: "light" (DDRNet-23-slim) and "heavy" (DDRNet-23). |
-| [SegmentationHead](../../nodes/README.md#segmentationhead) | segmentation_head | Head of the model. |
-| [BCEWithLogitsLoss](../../attached_modules/losses/README.md#bcewithlogitsloss) | segmentation_loss | Loss of the model when the task is set to "binary". |
-| [CrossEntropyLoss](../../attached_modules/losses/README.md#crossentropyloss) | segmentation_loss | Loss of the model when the task is set to "multiclass" or "multilabel". |
-| [JaccardIndex](../../attached_modules/metrics/README.md#torchmetrics) | segmentation_jaccard_index | Main metric of the model. |
-| [F1Score](../../attached_modules/metrics/README.md#torchmetrics) | segmentation_f1_score | Secondary metric of the model. |
-| [SegmentationVisualizer](../../attached_modules/visualizers/README.md#segmentationvisualizer) | segmentation_visualizer | Visualizer of the `SegmentationHead`. |
-
-**Params**
-
-| Key | Type | Default value | Description |
-| ----------------- | --------------------------------- | ------------- | ------------------------------------------------------------------------------------------------ |
-| variant | Literal\["light", "heavy"\] | "light" | Defines the variant of the model. "light" uses DDRNet-23-slim, "heavy" uses DDRNet-23. |
-| backbone | str | "DDRNet" | Name of the node to be used as a backbone. |
-| backbone_params | dict | {} | Additional parameters for the backbone. If not provided, variant-specific defaults will be used. |
-| head_params | dict | {} | Additional parameters for the head. |
-| aux_head_params | dict | {} | Additional parameters for auxiliary heads. |
-| loss_params | dict | {} | Additional parameters for the loss. |
-| visualizer_params | dict | {} | Additional parameters for the visualizer. |
-| task | Literal\["binary", "multiclass"\] | "binary" | Type of the task of the model. |
-| task_name | str \| None | None | Custom task name for the head. |
-
-## DetectionModel
-
-The `DetectionModel` allows for both "light" and "heavy" variants, where the "heavy" variant is more accurate, and the "light" variant is faster.
-
-See an example configuration file using this predefined model [here](../../../configs/detection_light_model.yaml) for the "light" variant, and [here](../../../configs/detection_heavy_model.yaml) for the "heavy" variant.
-
-**Components**
-
-| Name | Alias | Function |
-| -------------------------------------------------------------------------------------- | -------------------- | ------------------------------------------------------------------------------------------------- |
-| [EfficientRep](../../nodes/README.md#efficientrep) | detection_backbone | Backbone of the model. Available variants: "light" (EfficientRep-N) and "heavy" (EfficientRep-L). |
-| [RepPANNeck](../../nodes/README.md#reppanneck) | detection_neck | Neck of the model. |
-| [EfficientBBoxHead](../../nodes/README.md#efficientbboxhead) | detection_head | Head of the model. |
-| [AdaptiveDetectionLoss](../../attached_modules/losses/README.md#adaptivedetectionloss) | detection_loss | Loss of the model. |
-| [MeanAveragePrecision](../../attached_modules/metrics/README.md#meanaverageprecision) | detection_map | Main metric of the model. |
-| [BBoxVisualizer](../../attached_modules/visualizers/README.md#bboxvisualizer) | detection_visualizer | Visualizer of the `detection_head`. |
-
-**Params**
-
-| Key | Type | Default value | Description |
-| ----------------- | --------------------------- | -------------- | ------------------------------------------------------------------------------------------- |
-| variant | Literal\["light", "heavy"\] | "light" | Defines the variant of the model. "light" uses EfficientRep-N, "heavy" uses EfficientRep-L. |
-| use_neck | bool | True | Whether to include the neck in the model. |
-| backbone | str | "EfficientRep" | Name of the node to be used as a backbone. |
-| backbone_params | dict | {} | Additional parameters to the backbone. |
-| neck_params | dict | {} | Additional parameters to the neck. |
-| head_params | dict | {} | Additional parameters to the head. |
-| loss_params | dict | {} | Additional parameters to the loss. |
-| visualizer_params | dict | {} | Additional parameters to the visualizer. |
-| task_name | str \| None | None | Custom task name for the head. |
-
-## KeypointDetectionModel
-
-The `KeypointDetectionModel` allows for both "light" and "heavy" variants, where the "heavy" variant is more accurate, and the "light" variant is faster.
-
-See an example configuration file using this predefined model [here](../../../configs/keypoint_bbox_light_model.yaml) for the "light" variant, and [here](../../../configs/keypoint_bbox_heavy_model.yaml) for the "heavy" variant.
-
-**Components**
-
-| Name | Alias | Function |
-| ------------------------------------------------------------------------------------------------------- | ---------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------- |
-| [EfficientRep](../../nodes/README.md#efficientrep) | kpt_detection_backbone | Backbone of the model.. Available variants: "light" (EfficientRep-N) and "heavy" (EfficientRep-L). |
-| [RepPANNeck](../../nodes/README.md#reppanneck) | kpt_detection_neck | Neck of the model. |
-| [EfficientKeypointBBoxHead](../../nodes/README.md#efficientkeypointbboxhead) | kpt_detection_head | Head of the model. |
-| [EfficientKeypointBBoxLoss](../../attached_modules/losses/README.md#efficientkeypointbboxloss) | kpt_detection_loss | Loss of the model. |
-| [ObjectKeypointSimilarity](../../attached_modules/metrics/README.md#objectkeypointsimilarity) | kpt_detection_oks | Main metric of the model. |
-| [MeanAveragePrecisionKeypoints](../../attached_modules/metrics/README.md#meanaverageprecisionkeypoints) | kpt_detection_map | Secondary metric of the model. |
-| [BBoxVisualizer](../../attached_modules/visualizers/README.md#bboxvisualizer) | | Visualizer for bounding boxes. Combined with keypoint visualizer in [MultiVisualizer](../../attached_modules/visualizers/README.md#multivisualizer). |
-| [KeypointVisualizer](../../attached_modules/visualizers/README.md#keypointvisualizer) | | Visualizer for keypoints. Combined with keypoint visualizer in [MultiVisualizer](../../attached_modules/visualizers/README.md#multivisualizer) |
-
-**Params**
-
-| Key | Type | Default value | Description |
-| ---------------------- | --------------------------- | -------------- | ------------------------------------------------------------------------------------------- |
-| variant | Literal\["light", "heavy"\] | "light" | Defines the variant of the model. "light" uses EfficientRep-N, "heavy" uses EfficientRep-L. |
-| use_neck | bool | True | Whether to include the neck in the model. |
-| backbone | str | "EfficientRep" | Name of the node to be used as a backbone. |
-| backbone_params | dict | {} | Additional parameters to the backbone. |
-| neck_params | dict | {} | Additional parameters to the neck. |
-| head_params | dict | {} | Additional parameters to the head. |
-| loss_params | dict | {} | Additional parameters to the loss. |
-| kpt_visualizer_params | dict | {} | Additional parameters to the keypoint visualizer. |
-| bbox_visualizer_params | dict | {} | Additional parameters to the bbox visualizer. |
-| bbox_task_name | str \| None | None | Custom task name for the detection head. |
-| kpt_task_name | str \| None | None | Custom task name for the keypoint head. |
-
-## ClassificationModel
-
-The `ClassificationModel` allows for both "light" and "heavy" variants, where the "heavy" variant is more accurate, and the "light" variant is faster. Can be used for multiclass and multilabel tasks.
-
-See an example configuration file using this predefined model [here](../../../configs/classification_light_model.yaml) for the "light" variant, and [here](../../../configs/classification_heavy_model.yaml) for the "heavy" variant.
-
-**Components**
-
-| Name | Alias | Function |
-| ---------------------------------------------------------------------------- | ----------------------- | ----------------------------------------------------------------------------------------------------- |
-| [ResNet](../../nodes/README.md#resnet) | classification_backbone | Backbone of the model. The "light" variant uses ResNet-18, while the "heavy" variant uses ResNet-101. |
-| [ClassificationHead](../../nodes/README.md#classificationhead) | classification_head | Head of the model. |
-| [CrossEntropyLoss](../../attached_modules/losses/README.md#crossentropyloss) | classification_loss | Loss of the model. |
-| [F1Score](../../attached_modules/metrics/README.md#torchmetrics) | classification_f1_score | Main metric of the model. |
-| [Accuracy](../../attached_modules/metrics/README.md#torchmetrics) | classification_accuracy | Secondary metric of the model. |
-| [Recall](../../attached_modules/metrics/README.md#torchmetrics) | classification_recall | Secondary metric of the model. |
-
-**Params**
-
-| Key | Type | Default value | Description |
-| ----------------- | ------------------------------------- | ------------- | ----------------------------------------------------------------------------------- |
-| variant | Literal\["light", "heavy"\] | "light" | Defines the variant of the model. "light" uses ResNet-18, "heavy" uses ResNet-101. |
-| backbone | str | "ResNet" | Name of the node to be used as a backbone. |
-| backbone_params | dict | {} | Additional parameters to the backbone. |
-| head_params | dict | {} | Additional parameters to the head. |
-| loss_params | dict | {} | Additional parameters to the loss. |
-| visualizer_params | dict | {} | Additional parameters to the visualizer. |
-| task | Literal\["multiclass", "multilabel"\] | "multiclass" | Type of the task of the model. |
-| task_name | str \| None | None | Custom task name for the head. |
+- [`SegmentationModel`](#segmentationmodel)
+- [`DetectionModel`](#detectionmodel)
+- [`KeypointDetectionModel`](#keypointdetectionmodel)
+- [`ClassificationModel`](#classificationmodel)
+
+**Parameters:**
+
+| Key | Type | Default value | Description |
+| --------------------- | ---------------- | ------------- | -------------------------------------------------------------------- |
+| `name`                | `str`            | -             | Name of the predefined architecture. See the available options below |
+| `params` | `dict[str, Any]` | `{}` | Additional parameters of the predefined model |
+| `include_nodes` | `bool` | `True` | Whether to include nodes of the model |
+| `include_losses` | `bool` | `True` | Whether to include loss functions |
+| `include_metrics` | `bool` | `True` | Whether to include metrics |
+| `include_visualizers` | `bool` | `True` | Whether to include visualizers |
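+
+For reference, a predefined model is selected in the `model` section of the configuration file, roughly as in the example configuration files linked in the sections below. The snippet is only a sketch; the model name and parameter values are illustrative:
+
+```yaml
+model:
+  name: my_model
+  predefined_model:
+    name: SegmentationModel    # one of the architectures listed above
+    include_visualizers: true  # set to false to train without visualizers
+    params:                    # forwarded to the predefined model
+      variant: light
+      task: binary
+```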
+
+## `SegmentationModel`
+
+The `SegmentationModel` allows for both `"light"` and `"heavy"` variants, where the `"heavy"` variant is more accurate, and the `"light"` variant is faster.
+
+See an example configuration file using this predefined model [here](../../../configs/segmentation_light_model.yaml) for the `"light"` variant, and [here](../../../configs/segmentation_heavy_model.yaml) for the `"heavy"` variant.
+
+**Components:**
+
+| Name | Alias | Function |
+| ----------------------------------------------------------------------------------------------- | ------------------------------ | --------------------------------------------------------------------------------------------------- |
+| [`DDRNet`](../../nodes/README.md#ddrnet) | `"segmentation_backbone"` | Backbone of the model. Available variants: `"light"` (`DDRNet-23-slim`) and `"heavy"` (`DDRNet-23`) |
+| [`SegmentationHead`](../../nodes/README.md#segmentationhead) | `"segmentation_head"` | Head of the model |
+| [`BCEWithLogitsLoss`](../../attached_modules/losses/README.md#bcewithlogitsloss) | `"segmentation_loss"` | Loss of the model when the task is set to `"binary"` |
+| [`CrossEntropyLoss`](../../attached_modules/losses/README.md#crossentropyloss) | `"segmentation_loss"` | Loss of the model when the task is set to `"multiclass"` or `"multilabel"` |
+| [`JaccardIndex`](../../attached_modules/metrics/README.md#torchmetrics) | `"segmentation_jaccard_index"` | Main metric of the model |
+| [`F1Score`](../../attached_modules/metrics/README.md#torchmetrics) | `"segmentation_f1_score"` | Secondary metric of the model |
+| [`SegmentationVisualizer`](../../attached_modules/visualizers/README.md#segmentationvisualizer) | `"segmentation_visualizer"` | Visualizer of the `SegmentationHead` |
+
+**Parameters:**
+
+| Key | Type | Default value | Description |
+| ------------------- | --------------------------------- | ------------- | ----------------------------------------------------------------------------------------------- |
+| `variant` | `Literal["light", "heavy"]` | `"light"` | Defines the variant of the model. `"light"` uses `DDRNet-23-slim`, `"heavy"` uses `DDRNet-23` |
+| `backbone` | `str` | `"DDRNet"` | Name of the node to be used as a backbone |
+| `backbone_params` | `dict` | `{}` | Additional parameters for the backbone. If not provided, variant-specific defaults will be used |
+| `head_params` | `dict` | `{}` | Additional parameters for the head |
+| `aux_head_params` | `dict` | `{}` | Additional parameters for auxiliary heads |
+| `loss_params` | `dict` | `{}` | Additional parameters for the loss |
+| `visualizer_params` | `dict` | `{}` | Additional parameters for the visualizer |
+| `task` | `Literal["binary", "multiclass"]` | `"binary"` | Type of the task of the model |
+| `task_name` | `str \| None` | `None` | Custom task name for the head |
+
+## `DetectionModel`
+
+The `DetectionModel` allows for both `"light"` and `"heavy"` variants, where the `"heavy"` variant is more accurate, and the `"light"` variant is faster.
+
+See an example configuration file using this predefined model [here](../../../configs/detection_light_model.yaml) for the `"light"` variant, and [here](../../../configs/detection_heavy_model.yaml) for the `"heavy"` variant.
+
+**Components:**
+
+| Name | Alias | Function |
+| ---------------------------------------------------------------------------------------- | ------------------------ | -------------------------------------------------------------------------------------------------------- |
+| [`EfficientRep`](../../nodes/README.md#efficientrep) | `"detection_backbone"` | Backbone of the model. Available variants: `"light"` (`EfficientRep-N`) and `"heavy"` (`EfficientRep-L`) |
+| [`RepPANNeck`](../../nodes/README.md#reppanneck) | `"detection_neck"` | Neck of the model |
+| [`EfficientBBoxHead`](../../nodes/README.md#efficientbboxhead) | `"detection_head"` | Head of the model |
+| [`AdaptiveDetectionLoss`](../../attached_modules/losses/README.md#adaptivedetectionloss) | `"detection_loss"` | Loss of the model |
+| [`MeanAveragePrecision`](../../attached_modules/metrics/README.md#meanaverageprecision) | `"detection_map"` | Main metric of the model |
+| [`BBoxVisualizer`](../../attached_modules/visualizers/README.md#bboxvisualizer) | `"detection_visualizer"` | Visualizer of the `detection_head` |
+
+**Parameters:**
+
+| Key | Type | Default value | Description |
+| ------------------- | --------------------------- | ---------------- | -------------------------------------------------------------------------------------------------- |
+| `variant` | `Literal["light", "heavy"]` | `"light"` | Defines the variant of the model. `"light"` uses `EfficientRep-N`, `"heavy"` uses `EfficientRep-L` |
+| `use_neck` | `bool` | `True` | Whether to include the neck in the model |
+| `backbone` | `str` | `"EfficientRep"` | Name of the node to be used as a backbone |
+| `backbone_params` | `dict` | `{}` | Additional parameters to the backbone |
+| `neck_params` | `dict` | `{}` | Additional parameters to the neck |
+| `head_params` | `dict` | `{}` | Additional parameters to the head |
+| `loss_params` | `dict` | `{}` | Additional parameters to the loss |
+| `visualizer_params` | `dict` | `{}` | Additional parameters to the visualizer |
+| `task_name` | `str \| None` | `None` | Custom task name for the head |
+
+## `KeypointDetectionModel`
+
+The `KeypointDetectionModel` allows for both `"light"` and `"heavy"` variants, where the `"heavy"` variant is more accurate, and the `"light"` variant is faster.
+
+See an example configuration file using this predefined model [here](../../../configs/keypoint_bbox_light_model.yaml) for the `"light"` variant, and [here](../../../configs/keypoint_bbox_heavy_model.yaml) for the `"heavy"` variant.
+
+**Components:**
+
+| Name | Alias | Function |
+| --------------------------------------------------------------------------------------------------------- | ---------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| [`EfficientRep`](../../nodes/README.md#efficientrep) | `"kpt_detection_backbone"` | Backbone of the model. Available variants: `"light"` (`EfficientRep-N`) and `"heavy"` (`EfficientRep-L`) |
+| [`RepPANNeck`](../../nodes/README.md#reppanneck) | `"kpt_detection_neck"` | Neck of the model |
+| [`EfficientKeypointBBoxHead`](../../nodes/README.md#efficientkeypointbboxhead) | `"kpt_detection_head"` | Head of the model |
+| [`EfficientKeypointBBoxLoss`](../../attached_modules/losses/README.md#efficientkeypointbboxloss) | `"kpt_detection_loss"` | Loss of the model |
+| [`ObjectKeypointSimilarity`](../../attached_modules/metrics/README.md#objectkeypointsimilarity) | `"kpt_detection_oks"` | Main metric of the model |
+| [`MeanAveragePrecisionKeypoints`](../../attached_modules/metrics/README.md#meanaverageprecisionkeypoints) | `"kpt_detection_map"` | Secondary metric of the model |
+| [`BBoxVisualizer`](../../attached_modules/visualizers/README.md#bboxvisualizer)                            | `"kpt_detection_visualizer"` | Visualizer for bounding boxes. Combined with the keypoint visualizer using [`MultiVisualizer`](../../attached_modules/visualizers/README.md#multivisualizer)       |
+| [`KeypointVisualizer`](../../attached_modules/visualizers/README.md#keypointvisualizer)                    | `"kpt_detection_visualizer"` | Visualizer for keypoints. Combined with the bounding box visualizer using [`MultiVisualizer`](../../attached_modules/visualizers/README.md#multivisualizer)        |
+
+**Parameters:**
+
+| Key | Type | Default value | Description |
+| ------------------------ | --------------------------- | ---------------- | -------------------------------------------------------------------------------------------------- |
+| `variant` | `Literal["light", "heavy"]` | `"light"` | Defines the variant of the model. `"light"` uses `EfficientRep-N`, `"heavy"` uses `EfficientRep-L` |
+| `use_neck` | `bool` | `True` | Whether to include the neck in the model |
+| `backbone` | `str` | `"EfficientRep"` | Name of the node to be used as a backbone |
+| `backbone_params` | `dict` | `{}` | Additional parameters to the backbone |
+| `neck_params` | `dict` | `{}` | Additional parameters to the neck |
+| `head_params` | `dict` | `{}` | Additional parameters to the head |
+| `loss_params` | `dict` | `{}` | Additional parameters to the loss |
+| `kpt_visualizer_params` | `dict` | `{}` | Additional parameters to the keypoint visualizer |
+| `bbox_visualizer_params` | `dict` | `{}` | Additional parameters to the bounding box visualizer |
+| `bbox_task_name` | `str \| None` | `None` | Custom task name for the detection head |
+| `kpt_task_name` | `str \| None` | `None` | Custom task name for the keypoint head |
+
+## `ClassificationModel`
+
+The `ClassificationModel` allows for both `"light"` and `"heavy"` variants, where the `"heavy"` variant is more accurate, and the `"light"` variant is faster. It can be used for multi-class and multi-label tasks.
+
+See an example configuration file using this predefined model [here](../../../configs/classification_light_model.yaml) for the `"light"` variant, and [here](../../../configs/classification_heavy_model.yaml) for the `"heavy"` variant.
+
+**Components:**
+
+| Name | Alias | Function |
+| ------------------------------------------------------------------------------ | --------------------------- | ------------------------------------------------------------------------------------------------------------ |
+| [`ResNet`](../../nodes/README.md#resnet) | `"classification_backbone"` | Backbone of the model. The `"light"` variant uses `ResNet-18`, while the `"heavy"` variant uses `ResNet-101` |
+| [`ClassificationHead`](../../nodes/README.md#classificationhead) | `"classification_head"` | Head of the model |
+| [`CrossEntropyLoss`](../../attached_modules/losses/README.md#crossentropyloss) | `"classification_loss"` | Loss of the model |
+| [`F1Score`](../../attached_modules/metrics/README.md#torchmetrics)             | `"classification_f1_score"` | Main metric of the model                                                                                      |
+| [`Accuracy`](../../attached_modules/metrics/README.md#torchmetrics)            | `"classification_accuracy"` | Secondary metric of the model                                                                                 |
+| [`Recall`](../../attached_modules/metrics/README.md#torchmetrics)              | `"classification_recall"`   | Secondary metric of the model                                                                                 |
+
+**Parameters:**
+
+| Key | Type | Default value | Description |
+| ------------------- | ------------------------------------- | -------------- | ----------------------------------------------------------------------------------------- |
+| `variant` | `Literal["light", "heavy"]` | `"light"` | Defines the variant of the model. `"light"` uses `ResNet-18`, `"heavy"` uses `ResNet-101` |
+| `backbone` | `str` | `"ResNet"` | Name of the node to be used as a backbone |
+| `backbone_params` | `dict` | `{}` | Additional parameters to the backbone |
+| `head_params` | `dict` | `{}` | Additional parameters to the head |
+| `loss_params` | `dict` | `{}` | Additional parameters to the loss |
+| `visualizer_params` | `dict` | `{}` | Additional parameters to the visualizer |
+| `task` | `Literal["multiclass", "multilabel"]` | `"multiclass"` | Type of the task of the model |
+| `task_name` | `str \| None` | `None` | Custom task name for the head |
diff --git a/luxonis_train/core/core.py b/luxonis_train/core/core.py
index bc3d3673..46405b0c 100644
--- a/luxonis_train/core/core.py
+++ b/luxonis_train/core/core.py
@@ -285,17 +285,22 @@ def thread_exception_hook(args):
def export(
self,
- onnx_save_path: str | None = None,
- *,
+ onnx_save_path: str | Path | None = None,
weights: str | Path | None = None,
) -> None:
"""Runs export.
- @type onnx_path: str | None
- @param onnx_path: Path to .onnx model. If not specified, model will be saved
- to export directory with name specified in config file.
-
- @raises RuntimeError: If `onnxsim` fails to simplify the model.
+ @type onnx_save_path: str | Path | None
+ @param onnx_save_path: Path to where the exported ONNX model will be saved.
+ If not specified, model will be saved to the export directory
+ with the name specified in config file.
+ @type weights: str | Path | None
+ @param weights: Path to the checkpoint from which to load weights.
+ If not specified, the value of `model.weights` from the
+ configuration file will be used. The current weights of the
+ model will be temporarily replaced with the weights from the
+ specified checkpoint.
+ @raises RuntimeError: If C{onnxsim} fails to simplify the model.
"""
weights = weights or self.cfg.model.weights
@@ -311,8 +316,8 @@ def export(
export_path = export_save_dir / (
self.cfg.exporter.name or self.cfg.model.name
)
- onnx_save_path = onnx_save_path or str(
- export_path.with_suffix(".onnx")
+ onnx_save_path = str(
+ onnx_save_path or export_path.with_suffix(".onnx")
)
with replace_weights(self.lightning_module, weights):
@@ -381,6 +386,7 @@ def test(
self,
new_thread: Literal[False] = ...,
view: Literal["train", "test", "val"] = "val",
+ weights: str | Path | None = ...,
) -> Mapping[str, float]: ...
@overload
@@ -388,6 +394,7 @@ def test(
self,
new_thread: Literal[True] = ...,
view: Literal["train", "test", "val"] = "val",
+ weights: str | Path | None = ...,
) -> None: ...
@typechecked
@@ -395,6 +402,7 @@ def test(
self,
new_thread: bool = False,
view: Literal["train", "val", "test"] = "val",
+ weights: str | Path | None = None,
) -> Mapping[str, float] | None:
"""Runs testing.
@@ -405,61 +413,78 @@ def test(
@rtype: Mapping[str, float] | None
@return: If new_thread is False, returns a dictionary test
results.
+ @type weights: str | Path | None
+ @param weights: Path to the checkpoint from which to load weights.
+ If not specified, the value of `model.weights` from the
+ configuration file will be used. The current weights of the
+ model will be temporarily replaced with the weights from the
+ specified checkpoint.
"""
+ weights = weights or self.cfg.model.weights
loader = self.pytorch_loaders[view]
- if not new_thread:
- return self.pl_trainer.test(self.lightning_module, loader)[0]
- else: # pragma: no cover
- self.thread = threading.Thread(
- target=self.pl_trainer.test,
- args=(self.lightning_module, loader),
- daemon=True,
- )
- self.thread.start()
+ with replace_weights(self.lightning_module, weights):
+ if not new_thread:
+ return self.pl_trainer.test(self.lightning_module, loader)[0]
+ else: # pragma: no cover
+ self.thread = threading.Thread(
+ target=self.pl_trainer.test,
+ args=(self.lightning_module, loader),
+ daemon=True,
+ )
+ self.thread.start()
@typechecked
def infer(
self,
view: Literal["train", "val", "test"] = "val",
save_dir: str | Path | None = None,
- source_path: str | None = None,
+ source_path: str | Path | None = None,
+ weights: str | Path | None = None,
) -> None:
"""Runs inference.
@type view: str
@param view: Which split to run the inference on. Valid values
- are: 'train', 'val', 'test'. Defaults to "val".
+ are: C{"train"}, C{"val"}, C{"test"}. Defaults to C{"val"}.
@type save_dir: str | Path | None
@param save_dir: Directory where to save the visualizations. If
not specified, visualizations will be rendered on the
screen.
- @type source_path: str | None
+ @type source_path: str | Path | None
@param source_path: Path to the image file, video file or directory.
If None, defaults to using dataset images.
+ @type weights: str | Path | None
+ @param weights: Path to the checkpoint from which to load weights.
+ If not specified, the value of `model.weights` from the
+ configuration file will be used. The current weights of the
+ model will be temporarily replaced with the weights from the
+ specified checkpoint.
"""
self.lightning_module.eval()
+ weights = weights or self.cfg.model.weights
- if source_path:
- source_path_obj = Path(source_path)
- if source_path_obj.suffix.lower() in VIDEO_FORMATS:
- process_video(self, source_path_obj, view, save_dir)
- elif source_path_obj.is_file():
- process_images(self, [source_path_obj], view, save_dir)
- elif source_path_obj.is_dir():
- image_files = [
- f
- for f in source_path_obj.iterdir()
- if f.suffix.lower() in IMAGE_FORMATS
- ]
- process_images(self, image_files, view, save_dir)
+ with replace_weights(self.lightning_module, weights):
+ if source_path:
+ source_path_obj = Path(source_path)
+ if source_path_obj.suffix.lower() in VIDEO_FORMATS:
+ process_video(self, source_path_obj, view, save_dir)
+ elif source_path_obj.is_file():
+ process_images(self, [source_path_obj], view, save_dir)
+ elif source_path_obj.is_dir():
+ image_files = [
+ f
+ for f in source_path_obj.iterdir()
+ if f.suffix.lower() in IMAGE_FORMATS
+ ]
+ process_images(self, image_files, view, save_dir)
+ else:
+ raise ValueError(
+ f"Source path {source_path} is not a valid file or directory."
+ )
else:
- raise ValueError(
- f"Source path {source_path} is not a valid file or directory."
- )
- else:
- process_dataset_images(self, view, save_dir)
+ process_dataset_images(self, view, save_dir)
def tune(self) -> None:
"""Runs Optuna tunning of hyperparameters."""
@@ -622,15 +647,30 @@ def _objective(trial: optuna.trial.Trial) -> float:
)
wandb_parent_tracker.log_hyperparams(study.best_params)
- def archive(self, path: str | Path | None = None) -> Path:
+ def archive(
+ self,
+ path: str | Path | None = None,
+ weights: str | Path | None = None,
+ ) -> Path:
"""Generates an NN Archive out of a model executable.
@type path: str | Path | None
@param path: Path to the model executable. If not specified, the
model will be exported first.
+ @type weights: str | Path | None
+ @param weights: Path to the checkpoint from which to load weights.
+ If not specified, the value of `model.weights` from the
+ configuration file will be used. The current weights of the
+ model will be temporarily replaced with the weights from the
+ specified checkpoint.
@rtype: Path
@return: Path to the generated NN Archive.
"""
+ weights = weights or self.cfg.model.weights
+ with replace_weights(self.lightning_module, weights):
+ return self._archive(path)
+
+ def _archive(self, path: str | Path | None = None) -> Path:
from .utils.archive_utils import get_heads, get_inputs, get_outputs
archive_name = self.cfg.archiver.name or self.cfg.model.name
diff --git a/luxonis_train/loaders/README.md b/luxonis_train/loaders/README.md
new file mode 100644
index 00000000..0a1a5bca
--- /dev/null
+++ b/luxonis_train/loaders/README.md
@@ -0,0 +1,34 @@
+# Loaders
+
+## Table Of Contents
+
+- [`LuxonisLoaderTorch`](#luxonisloadertorch)
+ - [Implementing a custom loader](#implementing-a-custom-loader)
+
+## `LuxonisLoaderTorch`
+
+The default loader used with `LuxonisTrain`. It can either load data from an already created dataset in the `LuxonisDataFormat` or create a new dataset automatically from a set of supported formats.
+
+**Parameters:**
+
+| Key | Type | Default value | Description |
+| ----------------- | --------------------------------------------------------------------------------------------------------- | ------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `dataset_name` | `str` | `None` | Name of the dataset to load. If not provided, the `dataset_dir` must be provided instead |
+| `dataset_dir` | `str` | `None` | Path to the directory containing the dataset. If not provided, the `dataset_name` must be provided instead. Can be a path to a local directory or a URL. The data can be in a zip archive. New `LuxonisDataset` will be created using data from this directory and saved under the provided `dataset_name` |
+| `dataset_type` | `Literal["coco", "voc", "darknet", "yolov6", "yolov4", "createml", "tfcsv", "clsdir", "segmask"] \| None` | `None` | Type of the dataset. If not provided, the type will be inferred from the directory structure |
+| `team_id` | `str \| None` | `None` | Optional unique team identifier for the cloud |
+| `bucket_storage` | `Literal["local", "s3", "gcs"]` | `"local"` | Type of the bucket storage |
+| `delete_existing` | `bool` | `False` | Whether to delete the existing dataset with the same name. Only relevant if `dataset_dir` is provided. Use if you want to reparse the directory in case the data changed |
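+
+A minimal `loader` section of the configuration file might look roughly like this (a sketch only; the dataset name, directory, and type are placeholders):
+
+```yaml
+loader:
+  name: LuxonisLoaderTorch
+  params:
+    dataset_name: my_dataset      # used directly if the dataset already exists
+    dataset_dir: path/to/dataset  # only needed when the dataset should be created
+    dataset_type: coco            # optional, inferred from the directory structure
+```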
+
+### Implementing a custom loader
+
+To implement a custom loader, create a class that inherits from `BaseLoaderTorch` and implement the following methods (a minimal example is sketched below the list):
+
+- `input_shapes(self) -> dict[str, torch.Size]`: Returns a dictionary with input shapes for each input image.
+- `__len__(self) -> int`: Returns the number of samples in the dataset.
+- `__getitem__(self, idx: int) -> tuple[dict[str, torch.Tensor], dict[str, tuple[torch.Tensor, luxonis_train.enums.TaskType]]]`: Returns a tuple of the input tensors and the corresponding labels for the sample at the given index.
+- `get_classes(self) -> dict[str, list[str]]`: Returns a dictionary with class names for each task in the dataset.
+
+For loaders yielding keypoint tasks, you also have to implement the `get_n_keypoints(self) -> dict[str, int]` method.
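+
+The sketch below implements these methods for a toy dataset of random images with a single classification task. It is illustrative only: `MyLoader` and its values are placeholders, `BaseLoaderTorch` is assumed to be importable from `luxonis_train.loaders`, and the exact abstract interface (for example whether `input_shapes` is a method or a property) should be checked against the in-source documentation linked below.
+
+```python
+import torch
+
+from luxonis_train.enums import TaskType
+from luxonis_train.loaders import BaseLoaderTorch
+
+
+class MyLoader(BaseLoaderTorch):
+    """Toy loader yielding random images with a binary classification label."""
+
+    def input_shapes(self) -> dict[str, torch.Size]:
+        # One input group, named after `self.image_source` (defaults to "image").
+        return {self.image_source: torch.Size([3, 256, 256])}
+
+    def __len__(self) -> int:
+        return 16
+
+    def __getitem__(self, idx: int):
+        inputs = {self.image_source: torch.rand(3, 256, 256)}
+        # Each label is a tuple of the target tensor and its task type.
+        labels = {"classification": (torch.tensor([1.0]), TaskType.CLASSIFICATION)}
+        return inputs, labels
+
+    def get_classes(self) -> dict[str, list[str]]:
+        return {"classification": ["positive"]}
+```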
+
+For more information, consult the in-source [documentation](https://github.com/luxonis/luxonis-train/blob/main/luxonis_train/loaders/base_loader.py).
diff --git a/luxonis_train/loaders/base_loader.py b/luxonis_train/loaders/base_loader.py
index 25ffc922..0c056d98 100644
--- a/luxonis_train/loaders/base_loader.py
+++ b/luxonis_train/loaders/base_loader.py
@@ -34,7 +34,7 @@ def __init__(
def image_source(self) -> str:
"""Name of the input image group.
- Example: 'image'
+ Example: C{"image"}
@type: str
"""
diff --git a/luxonis_train/loaders/luxonis_loader_torch.py b/luxonis_train/loaders/luxonis_loader_torch.py
index b0e83a94..230128b5 100644
--- a/luxonis_train/loaders/luxonis_loader_torch.py
+++ b/luxonis_train/loaders/luxonis_loader_torch.py
@@ -1,5 +1,4 @@
import logging
-from pathlib import Path
from typing import Literal
import numpy as np
@@ -136,9 +135,7 @@ def _parse_dataset(
delete_existing: bool,
) -> LuxonisDataset:
if dataset_name is None:
- dataset_name = Path(dataset_dir).stem
- if dataset_type is not None:
- dataset_name += f"_{dataset_type.value}"
+ dataset_name = dataset_dir.split("/")[-1]
if LuxonisDataset.exists(dataset_name):
if not delete_existing:
diff --git a/luxonis_train/models/luxonis_lightning.py b/luxonis_train/models/luxonis_lightning.py
index 03d633c9..459b20d1 100644
--- a/luxonis_train/models/luxonis_lightning.py
+++ b/luxonis_train/models/luxonis_lightning.py
@@ -803,7 +803,7 @@ def configure_optimizers(
self,
) -> tuple[
list[torch.optim.Optimizer],
- list[torch.optim.lr_scheduler._LRScheduler],
+ list[torch.optim.lr_scheduler.LRScheduler],
]:
"""Configures model optimizers and schedulers."""
cfg_optimizer = self.cfg.trainer.optimizer
diff --git a/luxonis_train/nodes/README.md b/luxonis_train/nodes/README.md
index dad43921..9e561cc8 100644
--- a/luxonis_train/nodes/README.md
+++ b/luxonis_train/nodes/README.md
@@ -1,235 +1,235 @@
# Nodes
Nodes are the basic building structures of the model. They can be connected together
-arbitrarily as long as the two nodes are compatible with each other. We've grouped together nodes that are similar so it's easier to build an architecture that makes sense.
+arbitrarily as long as the two nodes are compatible with each other. We've grouped together nodes that are similar, so it's easier to build an architecture that makes sense.
## Table Of Contents
- [Backbones](#backbones)
- - [ResNet](#resnet)
- - [MicroNet](#micronet)
- - [RepVGG](#repvgg)
- - [EfficientRep](#efficientrep)
- - [RexNetV1_lite](#rexnetv1_lite)
- - [MobileOne](#mobileone)
- - [MobileNetV2](#mobilenetv2)
- - [EfficientNet](#efficientnet)
- - [ContextSpatial](#contextspatial)
- - [DDRNet](#ddrnet)
+ - [`ResNet`](#resnet)
+ - [`MicroNet`](#micronet)
+ - [`RepVGG`](#repvgg)
+ - [`EfficientRep`](#efficientrep)
+ - [`RexNetV1_lite`](#rexnetv1_lite)
+ - [`MobileOne`](#mobileone)
+ - [`MobileNetV2`](#mobilenetv2)
+ - [`EfficientNet`](#efficientnet)
+ - [`ContextSpatial`](#contextspatial)
+ - [`DDRNet`](#ddrnet)
- [Necks](#necks)
- - [RepPANNeck](#reppanneck)
+ - [`RepPANNeck`](#reppanneck)
- [Heads](#heads)
- - [ClassificationHead](#classificationhead)
- - [SegmentationHead](#segmentationhead)
- - [BiSeNetHead](#bisenethead)
- - [EfficientBBoxHead](#efficientbboxhead)
- - [EfficientKeypointBBoxHead](#efficientkeypointbboxhead)
- - [DDRNetSegmentationHead](#ddrnetsegmentationhead)
+ - [`ClassificationHead`](#classificationhead)
+ - [`SegmentationHead`](#segmentationhead)
+ - [`BiSeNetHead`](#bisenethead)
+ - [`EfficientBBoxHead`](#efficientbboxhead)
+ - [`EfficientKeypointBBoxHead`](#efficientkeypointbboxhead)
+ - [`DDRNetSegmentationHead`](#ddrnetsegmentationhead)
Every node takes these parameters:
-| Key | Type | Default value | Description |
-| ---------------- | ----------- | ------------- | ---------------------------------------------------------------------------- |
-| n_classes | int \| None | None | Number of classes in the dataset. Inferred from the dataset if not provided. |
-| remove_on_export | bool | False | Whether node should be removed when exporting the whole model. |
+| Key | Type | Default value | Description |
+| ------------------ | ------------- | ------------- | --------------------------------------------------------------------------- |
+| `n_classes` | `int \| None` | `None` | Number of classes in the dataset. Inferred from the dataset if not provided |
+| `remove_on_export` | `bool` | `False` | Whether node should be removed when exporting the whole model |
-In addition, the following class attributes can be overriden:
+In addition, the following class attributes can be overridden:
-| Key | Type | Default value | Description |
-| ------------ | ------------------------------------------------------------------- | ------------- | --------------------------------------------------------------------------------------------------------------------------------------------- |
-| attach_index | int \| "all" \| Tuple\[int, int\] \| Tuple\[int, int, int\] \| None | None | Index of previous output that the head attaches to. Each node has a sensible default. Usually should not be manually set in most cases. |
-| tasks | List\[TaskType\] \| Dict\[TaskType, str\] \| None | None | Tasks supported by the node. Should be overriden for head nodes. Either a list of tasks or a dictionary mapping tasks to their default names. |
+| Key | Type | Default value | Description |
+| -------------- | ----------------------------------------------------------------- | ------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `attach_index` | `int \| "all" \| tuple[int, int] \| tuple[int, int, int] \| None`  | `None`        | Index of the previous output that the node attaches to. Each node has a sensible default, so it rarely needs to be set manually. Can be either a single index, a slice (negative indexing is also supported), or `"all"` |
+| `tasks`        | `list[TaskType] \| dict[TaskType, str] \| None`                    | `None`        | Tasks supported by the node. Should be overridden for head nodes. Either a list of tasks or a dictionary mapping tasks to their default names |
Additional parameters for specific nodes are listed below.
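+
+For reference, nodes are declared in the `model.nodes` section of the configuration file, with node-specific options passed through `params`. The snippet below is only an illustrative sketch; the chosen nodes, connections, and values are placeholders:
+
+```yaml
+model:
+  nodes:
+    - name: EfficientRep
+      params:
+        variant: n
+
+    - name: RepPANNeck
+      inputs:
+        - EfficientRep
+      params:
+        n_heads: 3
+
+    - name: EfficientBBoxHead
+      inputs:
+        - RepPANNeck
+      params:
+        conf_thres: 0.25
+```
+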
## Backbones
-### ResNet
+### `ResNet`
Adapted from [here](https://pytorch.org/vision/main/models/resnet.html).
-**Params**
+**Parameters:**
-| Key | Type | Default value | Description |
-| ---------------- | ----------------------------------------- | ------------- | -------------------------------------- |
-| variant | Literal\["18", "34", "50", "101", "152"\] | "18" | Variant of the network. |
-| download_weights | bool | False | If True download weights from imagenet |
+| Key | Type | Default value | Description |
+| ------------------ | ----------------------------------------- | ------------- | -------------------------------------- |
+| `variant` | `Literal["18", "34", "50", "101", "152"]` | `"18"` | Variant of the network |
+| `download_weights` | `bool`                                    | `False`       | If `True`, downloads weights pretrained on ImageNet |
-### MicroNet
+### `MicroNet`
Adapted from [here](https://github.com/liyunsheng13/micronet).
-**Params**
+**Parameters:**
-| Key | Type | Default value | Description |
-| ------- | --------------------------- | ------------- | ----------------------- |
-| variant | Literal\["M1", "M2", "M3"\] | "M1" | Variant of the network. |
+| Key | Type | Default value | Description |
+| --------- | --------------------------- | ------------- | ---------------------- |
+| `variant` | `Literal["M1", "M2", "M3"]` | `"M1"` | Variant of the network |
-### RepVGG
+### `RepVGG`
Adapted from [here](https://github.com/DingXiaoH/RepVGG).
-**Params**
+**Parameters:**
-| Key | Type | Default value | Description |
-| ------- | --------------------------- | ------------- | ----------------------- |
-| variant | Literal\["A0", "A1", "A2"\] | "A0" | Variant of the network. |
+| Key | Type | Default value | Description |
+| --------- | --------------------------- | ------------- | ---------------------- |
+| `variant` | `Literal["A0", "A1", "A2"]` | `"A0"` | Variant of the network |
-### EfficientRep
+### `EfficientRep`
Adapted from [here](https://arxiv.org/pdf/2209.02976.pdf).
-**Params**
+**Parameters:**
-| Key | Type | Default value | Description |
-| ------------- | ----------------------------------------------------------------- | --------------------------- | --------------------------------------------------------------- |
-| variant | Literal\["n", "nano", "s", "small", "m", "medium", "l", "large"\] | "nano" | Variant of the network |
-| channels_list | List\[int\] | \[64, 128, 256, 512, 1024\] | List of number of channels for each block |
-| n_repeats | List\[int\] | \[1, 6, 12, 18, 6\] | List of number of repeats of RepVGGBlock |
-| depth_mul | float | 0.33 | Depth multiplier |
-| width_mul | float | 0.25 | Width multiplier |
-| block | Literal\["RepBlock", "CSPStackRepBlock"\] | "RepBlock" | Base block used |
-| csp_e | float | 0.5 | Factor for intermediate channels when block=="CSPStackRepBlock" |
+| Key | Type | Default value | Description |
+| --------------- | ----------------------------------------------------------------- | --------------------------- | -------------------------------------------------------------------------- |
+| `variant` | `Literal["n", "nano", "s", "small", "m", "medium", "l", "large"]` | `"nano"` | Variant of the network |
+| `channels_list` | `list[int]`                                                       | `[64, 128, 256, 512, 1024]` | List of number of channels for each block                                   |
+| `n_repeats`     | `list[int]`                                                       | `[1, 6, 12, 18, 6]`         | List of number of repeats of `RepVGGBlock`                                  |
+| `depth_mul` | `float` | `0.33` | Depth multiplier |
+| `width_mul` | `float` | `0.25` | Width multiplier |
+| `block` | `Literal["RepBlock", "CSPStackRepBlock"]` | `"RepBlock"` | Base block used |
+| `csp_e` | `float` | `0.5` | Factor for intermediate channels when block is set to `"CSPStackRepBlock"` |
-### RexNetV1_lite
+### `RexNetV1_lite`
-Adapted from ([here](https://github.com/clovaai/rexnet).
+Adapted from [here](https://github.com/clovaai/rexnet).
-**Params**
+**Parameters:**
-| Key | Type | Default value | Description |
-| --------------- | ------------------ | ------------- | ----------------------------- |
-| fix_head_stem | bool | False | Whether to multiply head stem |
-| divisible_value | int | 8 | Divisor used |
-| input_ch | int | 16 | tarting channel dimension |
-| final_ch | int | 164 | Final channel dimension |
-| multiplier | float | 1.0 | Channel dimension multiplier |
-| kernel_sizes | int \| list\[int\] | 3 | Kernel sizes |
+| Key | Type | Default value | Description |
+| ----------------- | ------------------ | ------------- | ----------------------------- |
+| `fix_head_stem` | `bool` | `False` | Whether to multiply head stem |
+| `divisible_value` | `int` | `8` | Divisor used |
+| `input_ch`        | `int`              | `16`          | Starting channel dimension    |
+| `final_ch` | `int` | `164` | Final channel dimension |
+| `multiplier` | `float` | `1.0` | Channel dimension multiplier |
+| `kernel_sizes` | `int \| list[int]` | `3` | Kernel sizes |
-### MobileOne
+### `MobileOne`
Adapted from [here](https://github.com/apple/ml-mobileone).
-**Params**
+**Parameters:**
-| Key | Type | Default value | Description |
-| ------- | --------------------------------------- | ------------- | ----------------------- |
-| variant | Literal\["s0", "s1", "s2", "s3", "s4"\] | "s0" | Variant of the network. |
+| Key | Type | Default value | Description |
+| --------- | --------------------------------------- | ------------- | ---------------------- |
+| `variant` | `Literal["s0", "s1", "s2", "s3", "s4"]` | `"s0"` | Variant of the network |
-### MobileNetV2
+### `MobileNetV2`
Adapted from [here](https://pytorch.org/vision/main/models/generated/torchvision.models.mobilenet_v2.html).
-**Params**
+**Parameters:**
-| Key | Type | Default value | Description |
-| ---------------- | ---- | ------------- | -------------------------------------- |
-| download_weights | bool | False | If True download weights from imagenet |
+| Key | Type | Default value | Description |
+| ------------------ | ------ | ------------- | -------------------------------------- |
+| `download_weights` | `bool` | `False`       | If `True`, downloads weights pretrained on ImageNet |
-### EfficientNet
+### `EfficientNet`
Adapted from [here](https://github.com/rwightman/gen-efficientnet-pytorch).
-**Params**
+**Parameters:**
-| Key | Type | Default value | Description |
-| ---------------- | ---- | ------------- | --------------------------------------- |
-| download_weights | bool | False | If True download weights from imagenet. |
+| Key | Type | Default value | Description |
+| ------------------ | ------ | ------------- | -------------------------------------- |
+| `download_weights` | `bool` | `False`       | If `True`, downloads weights pretrained on ImageNet |
-### ContextSpatial
+### `ContextSpatial`
Adapted from [here](https://github.com/taveraantonio/BiseNetv1).
-**Params**
+**Parameters:**
-| Key | Type | Default value | Description |
-| ---------------- | ---- | ------------- | ------------- |
-| context_backbone | str | "MobileNetV2" | Backbone used |
+| Key | Type | Default value | Description |
+| ------------------ | ----- | --------------- | ---------------------------------------------------------------------------------------------------- |
+| `context_backbone` | `str` | `"MobileNetV2"` | Backbone used for the context path. Must be a reference to a node registered in the `NODES` registry |
-### DDRNet
+### `DDRNet`
Adapted from [here](https://github.com/ydhongHIT/DDRNet)
-**Params**
+**Parameters:**
-| Key | Type | Default value | Description |
-| ------- | -------------------------- | ------------- | ----------------------- |
-| variant | Literal\["23-slim", "23"\] | "23-slim" | Variant of the network. |
+| Key | Type | Default value | Description |
+| --------- | -------------------------- | ------------- | ---------------------- |
+| `variant` | `Literal["23-slim", "23"]` | `"23-slim"` | Variant of the network |
## Neck
-### RepPANNeck
+### `RepPANNeck`
Adapted from [here](https://arxiv.org/pdf/2209.02976.pdf).
-**Params**
+**Parameters:**
-| Key | Type | Default value | Description |
-| ------------- | ----------------------------------------------------------------- | ------------------------------------------------------- | --------------------------------------------------------------- |
-| variant | Literal\["n", "nano", "s", "small", "m", "medium", "l", "large"\] | "nano" | Variant of the network |
-| n_heads | Literal\[2,3,4\] | 3 ***Note:** Should be same also on head in most cases* | Number of output heads |
-| channels_list | List\[int\] | \[256, 128, 128, 256, 256, 512\] | List of number of channels for each block |
-| n_repeats | List\[int\] | \[12, 12, 12, 12\] | List of number of repeats of RepVGGBlock |
-| depth_mul | float | 0.33 | Depth multiplier |
-| width_mul | float | 0.25 | Width multiplier |
-| block | Literal\["RepBlock", "CSPStackRepBlock"\] | "RepBlock" | Base block used |
-| csp_e | float | 0.5 | Factor for intermediate channels when block=="CSPStackRepBlock" |
+| Key | Type | Default value | Description |
+| --------------- | ----------------------------------------------------------------- | -------------------------------- | ------------------------------------------------------------------------------- |
+| `variant` | `Literal["n", "nano", "s", "small", "m", "medium", "l", "large"]` | `"nano"` | Variant of the network |
+| `n_heads`       | `Literal[2, 3, 4]`                                                | `3`                              | Number of output heads. In most cases this should match the `n_heads` of the connected head |
+| `channels_list` | `list[int]` | `[256, 128, 128, 256, 256, 512]` | List of number of channels for each block |
+| `n_repeats` | `list[int]` | `[12, 12, 12, 12]` | List of number of repeats of `RepVGGBlock` |
+| `depth_mul` | `float` | `0.33` | Depth multiplier |
+| `width_mul` | `float` | `0.25` | Width multiplier |
+| `block` | `Literal["RepBlock", "CSPStackRepBlock"]` | `"RepBlock"` | Base block used |
+| `csp_e` | `float` | `0.5` | Factor for intermediate channels when block is set to `"CSPStackRepBlock"` |
## Heads
-### ClassificationHead
+### `ClassificationHead`
-**Params**
+**Parameters:**
-| Key | Type | Default value | Description |
-| ---------- | ----- | ------------- | --------------------------------------------- |
-| fc_dropout | float | 0.2 | Dropout rate before last layer, range \[0,1\] |
+| Key | Type | Default value | Description |
+| ------------ | ------- | ------------- | ------------------------------------------------ |
+| `fc_dropout` | `float` | `0.2`         | Dropout rate before the last layer, in the range `[0, 1]` |
-### SegmentationHead
+### `SegmentationHead`
Adapted from [here](https://github.com/pytorch/vision/blob/main/torchvision/models/segmentation/fcn.py).
-### BiSeNetHead
+### `BiSeNetHead`
Adapted from [here](https://github.com/taveraantonio/BiseNetv1).
-**Params**
+**Parameters:**
-| Key | Type | Default value | Description |
-| --------------------- | ---- | ------------- | -------------------------------------- |
-| intermediate_channels | int | 64 | How many intermediate channels to use. |
+| Key | Type | Default value | Description |
+| ----------------------- | ----- | ------------- | ------------------------------------- |
+| `intermediate_channels` | `int` | `64` | How many intermediate channels to use |
-### EfficientBBoxHead
+### `EfficientBBoxHead`
Adapted from [here](https://arxiv.org/pdf/2209.02976.pdf).
-**Params**
+**Parameters:**
-| Key | Type | Default value | Description |
-| ---------- | ----- | ------------- | -------------------------------------------------- |
-| n_heads | bool | 3 | Number of output heads |
-| conf_thres | float | 0.25 | Confidence threshold for nms (used for evaluation) |
-| iou_thres | float | 0.45 | Iou threshold for nms (used for evaluation) |
+| Key | Type | Default value | Description |
+| ------------ | ------- | ------------- | --------------------------------------------------------------------- |
+| `n_heads`    | `int`   | `3`           | Number of output heads                                                 |
+| `conf_thres` | `float` | `0.25`        | Confidence threshold for non-maximum suppression (used for evaluation) |
+| `iou_thres`  | `float` | `0.45`        | IoU threshold for non-maximum suppression (used for evaluation)        |
-### EfficientKeypointBBoxHead
+### `EfficientKeypointBBoxHead`
Adapted from [here](https://arxiv.org/pdf/2207.02696.pdf).
-**Params**
+**Parameters:**
-| Key | Type | Default value | Description |
-| ----------- | ----------- | ------------- | -------------------------------------------------- |
-| n_keypoints | int \| None | None | Number of keypoints. |
-| n_heads | int | 3 | Number of output heads |
-| conf_thres | float | 0.25 | Confidence threshold for nms (used for evaluation) |
-| iou_thres | float | 0.45 | Iou threshold for nms (used for evaluation) |
+| Key | Type | Default value | Description |
+| ------------- | -------------- | ------------- | --------------------------------------------------------------------- |
+| `n_keypoints` | `int \| None` | `None`        | Number of keypoints                                                    |
+| `n_heads`     | `int`         | `3`           | Number of output heads                                                 |
+| `conf_thres`  | `float`       | `0.25`        | Confidence threshold for non-maximum suppression (used for evaluation) |
+| `iou_thres`   | `float`       | `0.45`        | IoU threshold for non-maximum suppression (used for evaluation)        |
-### DDRNetSegmentationHead
+### `DDRNetSegmentationHead`
Adapted from [here](https://github.com/ydhongHIT/DDRNet).
-**Params**
+**Parameters:**
-| Key | Type | Default value | Description |
-| -------------- | ---- | ------------- | ---------------------------------------------------------------------------------------------- |
-| inter_channels | int | 64 | Width of internal conv. Must be a multiple of scale_factor^2 when inter_mode is pixel_shuffle. |
-| inter_mode | str | "bilinear | Upsampling method. |
+| Key | Type | Default value | Description |
+| ---------------- | ----- | ------------- | ------------------------------------------------------------------------------------------------------------------------- |
+| `inter_channels` | `int` | `64` | Width of internal convolutions |
+| `inter_mode` | `str` | `"bilinear"` | Up-sampling method. One of `"nearest"`, `"linear"`, `"bilinear"`, `"bicubic"`, `"trilinear"`, `"area"`, `"pixel_shuffle"` |
diff --git a/luxonis_train/utils/registry.py b/luxonis_train/utils/registry.py
index 57ca0066..8044f13c 100644
--- a/luxonis_train/utils/registry.py
+++ b/luxonis_train/utils/registry.py
@@ -3,7 +3,7 @@
import lightning.pytorch as pl
from luxonis_ml.utils.registry import Registry
-from torch.optim.lr_scheduler import _LRScheduler
+from torch.optim.lr_scheduler import LRScheduler
from torch.optim.optimizer import Optimizer
import luxonis_train as lt
@@ -32,7 +32,7 @@
OPTIMIZERS: Registry[type[Optimizer]] = Registry(name="optimizers")
"""Registry for all optimizers."""
-SCHEDULERS: Registry[type[_LRScheduler]] = Registry(name="schedulers")
+SCHEDULERS: Registry[type[LRScheduler]] = Registry(name="schedulers")
"""Registry for all schedulers."""
VISUALIZERS: Registry[type["lt.visualizers.BaseVisualizer"]] = Registry(
diff --git a/media/coverage_badge.svg b/media/coverage_badge.svg
deleted file mode 100644
index ee07d4c2..00000000
--- a/media/coverage_badge.svg
+++ /dev/null
@@ -1,21 +0,0 @@
-
-
diff --git a/media/pybadge.svg b/media/pybadge.svg
deleted file mode 100644
index 983d6f42..00000000
--- a/media/pybadge.svg
+++ /dev/null
@@ -1 +0,0 @@
-
\ No newline at end of file
diff --git a/pyproject.toml b/pyproject.toml
index 67dc3d16..39c11a92 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,6 +12,8 @@ classifiers = [
"License :: OSI Approved :: Apache Software License",
"Development Status :: 4 - Beta",
"Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
+ "Programming Language :: Python :: 3.12",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Scientific/Engineering :: Image Processing",
"Topic :: Scientific/Engineering :: Image Recognition",
@@ -47,6 +49,7 @@ select = ["E4", "E7", "E9", "F", "W", "B", "I"]
[tool.docformatter]
black = true
+style = "epytext"
wrap-summaries = 72
wrap-descriptions = 72
diff --git a/tests/configs/ddrnet.yaml b/tests/configs/ddrnet.yaml
index e5c7ea9f..542fc0f6 100644
--- a/tests/configs/ddrnet.yaml
+++ b/tests/configs/ddrnet.yaml
@@ -21,12 +21,12 @@ model:
name: CrossEntropyLoss
trainer:
preprocessing:
- train_image_size:
- - &height 128
- - &width 128
- keep_aspect_ratio: False
+ train_image_size:
+ - 128
+ - 128
+ keep_aspect_ratio: false
normalize:
- active: True
+ active: true
batch_size: 2
epochs: &epochs 1
diff --git a/tests/configs/parking_lot_config.yaml b/tests/configs/parking_lot_config.yaml
index 78711178..5cda65c1 100644
--- a/tests/configs/parking_lot_config.yaml
+++ b/tests/configs/parking_lot_config.yaml
@@ -104,7 +104,7 @@ model:
tracker:
project_name: Parking_Lot
- is_tensorboard: True
+ is_tensorboard: true
loader:
train_view: val
@@ -118,23 +118,23 @@ trainer:
n_sanity_val_steps: 1
profiler: null
- verbose: True
+ verbose: true
batch_size: 2
accumulate_grad_batches: 1
epochs: 200
n_workers: 8
validation_interval: 10
n_log_images: 8
- skip_last_batch: True
- log_sub_losses: True
+ skip_last_batch: true
+ log_sub_losses: true
save_top_k: 3
preprocessing:
train_image_size: [256, 320]
- keep_aspect_ratio: False
- train_rgb: True
+ keep_aspect_ratio: false
+ train_rgb: true
normalize:
- active: True
+ active: true
augmentations:
- name: Defocus
params:
diff --git a/tests/configs/segmentation_parse_loader.yaml b/tests/configs/segmentation_parse_loader.yaml
index 14814571..178a89cb 100644
--- a/tests/configs/segmentation_parse_loader.yaml
+++ b/tests/configs/segmentation_parse_loader.yaml
@@ -16,9 +16,9 @@ loader:
trainer:
preprocessing:
train_image_size: [&height 128, &width 128]
- keep_aspect_ratio: False
+ keep_aspect_ratio: false
normalize:
- active: True
+ active: true
batch_size: 4
epochs: &epochs 1