Skip to content

Commit

Permalink
Resnet Variants (#9)
Browse files Browse the repository at this point in the history
* Added ResNet variants

* ResNet50 example

* Simplified example

* fixed resnet config

* [Automated] Updated coverage badge

---------

Co-authored-by: Martin Kozlovsky <[email protected]>
Co-authored-by: GitHub Actions <[email protected]>
  • Loading branch information
3 people committed Oct 9, 2024
1 parent 2e82131 commit afade1f
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 17 deletions.
57 changes: 57 additions & 0 deletions configs/resnet_model.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@

model:
name: resnet50_classification
nodes:
- name: ResNet
variant: "50"
download_weights: True

- name: ClassificationHead
inputs:
- ResNet

losses:
- name: CrossEntropyLoss
attached_to: ClassificationHead

metrics:
- name: Accuracy
is_main_metric: true
attached_to: ClassificationHead

visualizers:
- name: ClassificationVisualizer
attached_to: ClassificationHead
params:
font_scale: 0.5
color: [255, 0, 0]
thickness: 2
include_plot: True

dataset:
name: cifar10_test

trainer:
batch_size: 4
epochs: &epochs 200
num_workers: 4
validation_interval: 10
num_log_images: 8

preprocessing:
train_image_size: [&height 224, &width 224]
keep_aspect_ratio: False
normalize:
active: True

callbacks:
- name: ExportOnTrainEnd
- name: TestOnTrainEnd

optimizer:
name: SGD
params:
lr: 0.02

scheduler:
name: ConstantLR
13 changes: 7 additions & 6 deletions luxonis_train/nodes/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ arbitrarily as long as the two nodes are compatible with each other.

## Table Of Contents

- [ResNet18](#resnet18)
- [ResNet](#resnet)
- [MicroNet](#micronet)
- [RepVGG](#repvgg)
- [EfficientRep](#efficientrep)
Expand All @@ -30,15 +30,16 @@ Every node takes these parameters:

Additional parameters for specific nodes are listed below.

## ResNet18
## ResNet

Adapted from [here](https://pytorch.org/vision/main/models/generated/torchvision.models.resnet18.html).
Adapted from [here](https://pytorch.org/vision/main/models/resnet.html).

**Params**

| Key | Type | Default value | Description |
| ---------------- | ---- | ------------- | -------------------------------------- |
| download_weights | bool | False | If True download weights from imagenet |
| Key | Type | Default value | Description |
| ---------------- | ----------------------------------------- | ------------- | -------------------------------------- |
| variant | Literal\["18", "34", "50", "101", "152"\] | "18" | Variant of the network. |
| download_weights | bool | False | If True download weights from imagenet |

## MicroNet

Expand Down
4 changes: 2 additions & 2 deletions luxonis_train/nodes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from .mobileone import MobileOne
from .reppan_neck import RepPANNeck
from .repvgg import RepVGG
from .resnet18 import ResNet18
from .resnet import ResNet
from .rexnetv1 import ReXNetV1_lite
from .segmentation_head import SegmentationHead

Expand All @@ -28,6 +28,6 @@
"ReXNetV1_lite",
"RepPANNeck",
"RepVGG",
"ResNet18",
"ResNet",
"SegmentationHead",
]
30 changes: 23 additions & 7 deletions luxonis_train/nodes/resnet18.py → luxonis_train/nodes/resnet.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,32 @@
"""ResNet18 backbone.
"""ResNet backbone.
Source: U{https://pytorch.org/vision/main/models/generated/
torchvision.models.resnet18.html}
Source: U{https://pytorch.org/vision/main/models/resnet.html}
@license: U{PyTorch<https://github.com/pytorch/pytorch/blob/master/LICENSE>}
"""

from typing import Literal

import torchvision
from torch import Tensor

from .base_node import BaseNode


class ResNet18(BaseNode[Tensor, list[Tensor]]):
class ResNet(BaseNode[Tensor, list[Tensor]]):
attach_index: int = -1

def __init__(
self,
variant: Literal["18", "34", "50", "101", "152"] = "18",
channels_list: list[int] | None = None,
download_weights: bool = False,
**kwargs,
):
"""Implementation of the ResNet18 backbone.
"""Implementation of the ResNetX backbone.
TODO: add more info
@type variant: Literal["18", "34", "50", "101", "152"]
@param variant: ResNet variant. Defaults to "18".
@type channels_list: list[int] | None
@param channels_list: List of channels to return.
If unset, defaults to [64, 128, 256, 512].
Expand All @@ -35,7 +37,12 @@ def __init__(
"""
super().__init__(**kwargs)

self.backbone = torchvision.models.resnet18(
if variant not in RESNET_VARIANTS:
raise ValueError(
f"ResNet model variant should be in {list(RESNET_VARIANTS.keys())}"
)

self.backbone = RESNET_VARIANTS[variant](
weights="DEFAULT" if download_weights else None
)
self.channels_list = channels_list or [64, 128, 256, 512]
Expand All @@ -57,3 +64,12 @@ def forward(self, x: Tensor) -> list[Tensor]:
outs.append(x)

return outs


RESNET_VARIANTS = {
"18": torchvision.models.resnet18,
"34": torchvision.models.resnet34,
"50": torchvision.models.resnet50,
"101": torchvision.models.resnet101,
"152": torchvision.models.resnet152,
}
4 changes: 2 additions & 2 deletions media/coverage_badge.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit afade1f

Please sign in to comment.