Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resnet variants #9

Merged
merged 5 commits into from
Feb 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions configs/resnet_model.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@

model:
name: resnet50_classification
nodes:
- name: ResNet
variant: "50"
download_weights: True

- name: ClassificationHead
inputs:
- ResNet

losses:
- name: CrossEntropyLoss
attached_to: ClassificationHead

metrics:
- name: Accuracy
is_main_metric: true
attached_to: ClassificationHead

visualizers:
- name: ClassificationVisualizer
attached_to: ClassificationHead
params:
font_scale: 0.5
color: [255, 0, 0]
thickness: 2
include_plot: True

dataset:
name: cifar10_test

trainer:
batch_size: 4
epochs: &epochs 200
num_workers: 4
validation_interval: 10
num_log_images: 8

preprocessing:
train_image_size: [&height 224, &width 224]
keep_aspect_ratio: False
normalize:
active: True

callbacks:
- name: ExportOnTrainEnd
- name: TestOnTrainEnd

optimizer:
name: SGD
params:
lr: 0.02

scheduler:
name: ConstantLR
13 changes: 7 additions & 6 deletions luxonis_train/nodes/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ arbitrarily as long as the two nodes are compatible with each other.

## Table Of Contents

- [ResNet18](#resnet18)
- [ResNet](#resnet)
- [MicroNet](#micronet)
- [RepVGG](#repvgg)
- [EfficientRep](#efficientrep)
Expand All @@ -30,15 +30,16 @@ Every node takes these parameters:

Additional parameters for specific nodes are listed below.

## ResNet18
## ResNet

Adapted from [here](https://pytorch.org/vision/main/models/generated/torchvision.models.resnet18.html).
Adapted from [here](https://pytorch.org/vision/main/models/resnet.html).

**Params**

| Key | Type | Default value | Description |
| ---------------- | ---- | ------------- | -------------------------------------- |
| download_weights | bool | False | If True download weights from imagenet |
| Key | Type | Default value | Description |
| ---------------- | ----------------------------------------- | ------------- | -------------------------------------- |
| variant | Literal\["18", "34", "50", "101", "152"\] | "18" | Variant of the network. |
| download_weights | bool | False | If True download weights from imagenet |

## MicroNet

Expand Down
4 changes: 2 additions & 2 deletions luxonis_train/nodes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from .mobileone import MobileOne
from .reppan_neck import RepPANNeck
from .repvgg import RepVGG
from .resnet18 import ResNet18
from .resnet import ResNet
from .rexnetv1 import ReXNetV1_lite
from .segmentation_head import SegmentationHead

Expand All @@ -28,6 +28,6 @@
"ReXNetV1_lite",
"RepPANNeck",
"RepVGG",
"ResNet18",
"ResNet",
"SegmentationHead",
]
30 changes: 23 additions & 7 deletions luxonis_train/nodes/resnet18.py → luxonis_train/nodes/resnet.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,32 @@
"""ResNet18 backbone.
"""ResNet backbone.

Source: U{https://pytorch.org/vision/main/models/generated/
torchvision.models.resnet18.html}
Source: U{https://pytorch.org/vision/main/models/resnet.html}
@license: U{PyTorch<https://github.com/pytorch/pytorch/blob/master/LICENSE>}
"""

from typing import Literal

import torchvision
from torch import Tensor

from .base_node import BaseNode


class ResNet18(BaseNode[Tensor, list[Tensor]]):
class ResNet(BaseNode[Tensor, list[Tensor]]):
attach_index: int = -1

def __init__(
self,
variant: Literal["18", "34", "50", "101", "152"] = "18",
channels_list: list[int] | None = None,
download_weights: bool = False,
**kwargs,
):
"""Implementation of the ResNet18 backbone.
"""Implementation of the ResNetX backbone.

TODO: add more info

@type variant: Literal["18", "34", "50", "101", "152"]
@param variant: ResNet variant. Defaults to "18".
@type channels_list: list[int] | None
@param channels_list: List of channels to return.
If unset, defaults to [64, 128, 256, 512].
Expand All @@ -35,7 +37,12 @@ def __init__(
"""
super().__init__(**kwargs)

self.backbone = torchvision.models.resnet18(
if variant not in RESNET_VARIANTS:
raise ValueError(
f"ResNet model variant should be in {list(RESNET_VARIANTS.keys())}"
)

self.backbone = RESNET_VARIANTS[variant](
weights="DEFAULT" if download_weights else None
)
self.channels_list = channels_list or [64, 128, 256, 512]
Expand All @@ -57,3 +64,12 @@ def forward(self, x: Tensor) -> list[Tensor]:
outs.append(x)

return outs


RESNET_VARIANTS = {
"18": torchvision.models.resnet18,
"34": torchvision.models.resnet34,
"50": torchvision.models.resnet50,
"101": torchvision.models.resnet101,
"152": torchvision.models.resnet152,
}
4 changes: 2 additions & 2 deletions media/coverage_badge.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.