Feat: Added new rule in config auto-population and updated documentation. #140

Merged · 5 commits · Dec 5, 2024
28 changes: 26 additions & 2 deletions configs/README.md
`@@ -225,8 +225,7 @@` Here you can change everything related to actual training of the model.
| `save_top_k` | `-1 \| NonNegativeInt` | `3` | Save top K checkpoints based on validation loss when training |
| `smart_cfg_auto_populate` | `bool` | `True` | Automatically populate sensible default values for missing config fields and log warnings |
| `n_validation_batches` | `PositiveInt \| None` | `None` | Limits the number of validation/test batches and makes the val/test loaders deterministic |

**Example:**

```yaml
trainer:
  skip_last_batch: true
  log_sub_losses: true
  save_top_k: 3
  smart_cfg_auto_populate: true
```

### Smart Configuration Auto-population

When `trainer.smart_cfg_auto_populate` is set to `True`, the following rules are applied automatically to fill in missing configuration fields with sensible defaults:

#### Auto-population Rules

1. **Default Optimizer and Scheduler:**

- If `training_strategy` is not defined and neither `optimizer` nor `scheduler` is set, the following defaults are applied:
- Optimizer: `Adam`
- Scheduler: `ConstantLR`

1. **CosineAnnealingLR Adjustment:**

- If the `CosineAnnealingLR` scheduler is used and `T_max` is not set, it is automatically set to the number of epochs.

1. **Mosaic4 Augmentation:**

- If `Mosaic4` augmentation is used without `out_width` and `out_height` parameters, they are set to match the training image size.

1. **Validation/Test Views:**

- If `train_view`, `val_view`, and `test_view` are the same, and `n_validation_batches` is not explicitly set, it defaults to `10` to prevent validation/testing on the entire training set.
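Taken together, the four rules above can be sketched as a single pass over the config. This is an illustrative condensation only, not the library's actual implementation: the `SimpleNamespace` stand-in and the flat attribute names are assumptions based on the fields documented in this README.

```python
from types import SimpleNamespace

def smart_auto_populate(cfg):
    """Condensed sketch of the four auto-population rules (illustrative only)."""
    # Rule 1: default optimizer/scheduler when no training strategy is set
    if cfg.training_strategy is None and cfg.optimizer is None and cfg.scheduler is None:
        cfg.optimizer = {"name": "Adam"}
        cfg.scheduler = {"name": "ConstantLR"}

    # Rule 2: CosineAnnealingLR needs T_max; default it to the number of epochs
    if cfg.scheduler and cfg.scheduler["name"] == "CosineAnnealingLR":
        cfg.scheduler.setdefault("params", {}).setdefault("T_max", cfg.epochs)

    # Rule 3: Mosaic4 output size defaults to the training image size
    for aug in cfg.augmentations:
        if aug["name"] == "Mosaic4":
            h, w = cfg.train_image_size
            aug.setdefault("params", {}).setdefault("out_height", h)
            aug.setdefault("params", {}).setdefault("out_width", w)

    # Rule 4: identical train/val/test views -> cap validation batches at 10
    if (cfg.train_view == cfg.val_view == cfg.test_view
            and cfg.n_validation_batches is None):
        cfg.n_validation_batches = 10
    return cfg

cfg = SimpleNamespace(
    training_strategy=None, optimizer=None, scheduler=None, epochs=100,
    augmentations=[{"name": "Mosaic4"}], train_image_size=(384, 512),
    train_view="train", val_view="train", test_view="train",
    n_validation_batches=None,
)
smart_auto_populate(cfg)
print(cfg.optimizer)             # {'name': 'Adam'}
print(cfg.n_validation_batches)  # 10
```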

### Preprocessing

We use the [`Albumentations`](https://albumentations.ai/docs/) library for `augmentations`. [Here](https://albumentations.ai/docs/api_reference/full_reference/#pixel-level-transforms) you can see a list of all supported pixel-level transforms, and [here](https://albumentations.ai/docs/api_reference/full_reference/#spatial-level-transforms) all supported spatial-level transforms. In the configuration you can specify any augmentation from these lists along with its parameters.
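For illustration, an augmentations list might look like the snippet below. This is a hypothetical example: the `trainer.preprocessing.augmentations` path and the chosen parameters are assumptions, though `Defocus` and `Affine` are real Albumentations transforms.

```yaml
trainer:
  preprocessing:
    augmentations:
      - name: Defocus
        params:
          p: 0.1
      - name: Affine
        params:
          p: 0.5
          rotate: 15
```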
13 changes: 13 additions & 0 deletions luxonis_train/config/config.py
```python
# @@ -580,6 +580,19 @@ def smart_auto_populate(cls, instance: "Config") -> None:
        "`Mosaic4` augmentation detected. Automatically set `out_width` and `out_height` to match `train_image_size`."
    )

    # Rule: If train, val, and test views are the same, set `n_validation_batches`
    if (
        instance.loader.train_view
        == instance.loader.val_view
        == instance.loader.test_view
        and instance.trainer.n_validation_batches is None
    ):
        instance.trainer.n_validation_batches = 10
        logger.warning(
            "Train, validation, and test views are the same. Automatically set `n_validation_batches` to 10 to prevent validation/testing on the full train set. "
            "If this behavior is not desired, set `smart_cfg_auto_populate` to `False`."
        )
```


```python
def is_acyclic(graph: dict[str, list[str]]) -> bool:
    """Tests if graph is acyclic."""
```
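The body of `is_acyclic` is collapsed in the diff view. For context, an acyclicity test on a directed graph is typically written as a three-color depth-first search; the sketch below is a generic implementation under that assumption, not necessarily what the repository does.

```python
def is_acyclic(graph: dict[str, list[str]]) -> bool:
    """Return True if the directed graph has no cycles.

    Generic three-color DFS sketch, not necessarily the repository's code.
    """
    WHITE, GRAY, BLACK = 0, 1, 2  # unvisited / on current DFS path / finished
    color = {node: WHITE for node in graph}

    def visit(node: str) -> bool:
        color[node] = GRAY
        for succ in graph.get(node, []):
            if color.get(succ, WHITE) == GRAY:  # back edge -> cycle found
                return False
            if color.get(succ, WHITE) == WHITE and not visit(succ):
                return False
        color[node] = BLACK
        return True

    return all(visit(n) for n in graph if color[n] == WHITE)

print(is_acyclic({"a": ["b"], "b": []}))     # True
print(is_acyclic({"a": ["b"], "b": ["a"]}))  # False
```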