diff --git a/configs/resnet_model.yaml b/configs/resnet_model.yaml new file mode 100644 index 00000000..7e93d269 --- /dev/null +++ b/configs/resnet_model.yaml @@ -0,0 +1,57 @@ + +model: + name: resnet50_classification + nodes: + - name: ResNet + variant: "50" + download_weights: True + + - name: ClassificationHead + inputs: + - ResNet + + losses: + - name: CrossEntropyLoss + attached_to: ClassificationHead + + metrics: + - name: Accuracy + is_main_metric: true + attached_to: ClassificationHead + + visualizers: + - name: ClassificationVisualizer + attached_to: ClassificationHead + params: + font_scale: 0.5 + color: [255, 0, 0] + thickness: 2 + include_plot: True + +dataset: + name: cifar10_test + +trainer: + batch_size: 4 + epochs: &epochs 200 + num_workers: 4 + validation_interval: 10 + num_log_images: 8 + + preprocessing: + train_image_size: [&height 224, &width 224] + keep_aspect_ratio: False + normalize: + active: True + + callbacks: + - name: ExportOnTrainEnd + - name: TestOnTrainEnd + + optimizer: + name: SGD + params: + lr: 0.02 + + scheduler: + name: ConstantLR diff --git a/luxonis_train/nodes/README.md b/luxonis_train/nodes/README.md index bd44ac5a..637c5026 100644 --- a/luxonis_train/nodes/README.md +++ b/luxonis_train/nodes/README.md @@ -5,7 +5,7 @@ arbitrarily as long as the two nodes are compatible with each other. ## Table Of Contents -- [ResNet18](#resnet18) +- [ResNet](#resnet) - [MicroNet](#micronet) - [RepVGG](#repvgg) - [EfficientRep](#efficientrep) @@ -30,15 +30,16 @@ Every node takes these parameters: Additional parameters for specific nodes are listed below. -## ResNet18 +## ResNet -Adapted from [here](https://pytorch.org/vision/main/models/generated/torchvision.models.resnet18.html). +Adapted from [here](https://pytorch.org/vision/main/models/resnet.html). **Params** -| Key | Type | Default value | Description | -| ---------------- | ---- | ------------- | -------------------------------------- | -| download_weights | bool | False | If True download weights from imagenet | +| Key | Type | Default value | Description | +| ---------------- | ----------------------------------------- | ------------- | -------------------------------------- | +| variant | Literal\["18", "34", "50", "101", "152"\] | "18" | Variant of the network. | +| download_weights | bool | False | If True download weights from imagenet | ## MicroNet diff --git a/luxonis_train/nodes/__init__.py b/luxonis_train/nodes/__init__.py index d7ec70d0..954db2be 100644 --- a/luxonis_train/nodes/__init__.py +++ b/luxonis_train/nodes/__init__.py @@ -10,7 +10,7 @@ from .mobileone import MobileOne from .reppan_neck import RepPANNeck from .repvgg import RepVGG -from .resnet18 import ResNet18 +from .resnet import ResNet from .rexnetv1 import ReXNetV1_lite from .segmentation_head import SegmentationHead @@ -28,6 +28,6 @@ "ReXNetV1_lite", "RepPANNeck", "RepVGG", - "ResNet18", + "ResNet", "SegmentationHead", ] diff --git a/luxonis_train/nodes/resnet18.py b/luxonis_train/nodes/resnet.py similarity index 61% rename from luxonis_train/nodes/resnet18.py rename to luxonis_train/nodes/resnet.py index 9c38681a..14ff8066 100644 --- a/luxonis_train/nodes/resnet18.py +++ b/luxonis_train/nodes/resnet.py @@ -1,10 +1,9 @@ -"""ResNet18 backbone. +"""ResNet backbone. -Source: U{https://pytorch.org/vision/main/models/generated/ -torchvision.models.resnet18.html} +Source: U{https://pytorch.org/vision/main/models/resnet.html} @license: U{PyTorch} """ - +from typing import Literal import torchvision from torch import Tensor @@ -12,19 +11,22 @@ from .base_node import BaseNode -class ResNet18(BaseNode[Tensor, list[Tensor]]): +class ResNet(BaseNode[Tensor, list[Tensor]]): attach_index: int = -1 def __init__( self, + variant: Literal["18", "34", "50", "101", "152"] = "18", channels_list: list[int] | None = None, download_weights: bool = False, **kwargs, ): - """Implementation of the ResNet18 backbone. + """Implementation of the ResNetX backbone. TODO: add more info + @type variant: Literal["18", "34", "50", "101", "152"] + @param variant: ResNet variant. Defaults to "18". @type channels_list: list[int] | None @param channels_list: List of channels to return. If unset, defaults to [64, 128, 256, 512]. @@ -35,7 +37,12 @@ def __init__( """ super().__init__(**kwargs) - self.backbone = torchvision.models.resnet18( + if variant not in RESNET_VARIANTS: + raise ValueError( + f"ResNet model variant should be in {list(RESNET_VARIANTS.keys())}" + ) + + self.backbone = RESNET_VARIANTS[variant]( weights="DEFAULT" if download_weights else None ) self.channels_list = channels_list or [64, 128, 256, 512] @@ -57,3 +64,12 @@ def forward(self, x: Tensor) -> list[Tensor]: outs.append(x) return outs + + +RESNET_VARIANTS = { + "18": torchvision.models.resnet18, + "34": torchvision.models.resnet34, + "50": torchvision.models.resnet50, + "101": torchvision.models.resnet101, + "152": torchvision.models.resnet152, +} diff --git a/media/coverage_badge.svg b/media/coverage_badge.svg index 12876e69..4033e89e 100644 --- a/media/coverage_badge.svg +++ b/media/coverage_badge.svg @@ -15,7 +15,7 @@ coverage coverage - 78% - 78% + 79% + 79%