Skip to content

Commit

Permalink
Train-Only Heads (#83)
Browse files Browse the repository at this point in the history
  • Loading branch information
JSabadin authored Sep 28, 2024
1 parent a7048e2 commit 121c8b9
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 1 deletion.
1 change: 1 addition & 0 deletions luxonis_train/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class ModelNodeConfig(BaseModelExtraForbid):
inputs: list[str] = [] # From preceding nodes
input_sources: list[str] = [] # From data loader
freezing: FreezingConfig = FreezingConfig()
remove_on_export: bool = False
task: str | dict[TaskType, str] | None = None
params: Params = {}

Expand Down
8 changes: 7 additions & 1 deletion luxonis_train/models/luxonis_lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,11 @@ def __init__(
}
nodes[node_name] = (
Node,
{**node_cfg.params, "_tasks": node_cfg.task},
{
**node_cfg.params,
"_tasks": node_cfg.task,
"remove_on_export": node_cfg.remove_on_export,
},
)

# Handle inputs for this node
Expand Down Expand Up @@ -373,6 +377,8 @@ def forward(
for node_name, node, input_names, unprocessed in traverse_graph(
self.graph, cast(dict[str, BaseNode], self.nodes)
):
if node.export and node.remove_on_export:
continue
input_names += self.node_input_sources[node_name]

node_inputs: list[Packet[Tensor]] = []
Expand Down
7 changes: 7 additions & 0 deletions luxonis_train/nodes/base_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def __init__(
n_classes: int | None = None,
n_keypoints: int | None = None,
in_sizes: Size | list[Size] | None = None,
remove_on_export: bool = False,
attach_index: AttachIndexType | None = None,
_tasks: dict[TaskType, str] | None = None,
):
Expand Down Expand Up @@ -187,6 +188,7 @@ class L{tasks} attribute. Shouldn't be provided by the user in most cases.
self._n_classes = n_classes
self._n_keypoints = n_keypoints
self._export = False
self._remove_on_export = remove_on_export
self._epoch = 0
self._in_sizes = in_sizes

Expand Down Expand Up @@ -507,6 +509,11 @@ def set_export_mode(self, mode: bool = True) -> None:
"""
self._export = mode

@property
def remove_on_export(self) -> bool:
"""Getter for the remove_on_export attribute."""
return self._remove_on_export

def unwrap(self, inputs: list[Packet[Tensor]]) -> ForwardInputT:
"""Prepares inputs for the forward pass.
Expand Down
50 changes: 50 additions & 0 deletions tests/configs/ddrnet.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
model:
name: ddrnet_segmentation
nodes:
- name: DDRNet
- name: DDRNetSegmentationHead
inputs: ["DDRNet"]
alias: "segmentation_head"
params:
attach_index: -1
- name: DDRNetSegmentationHead
inputs: ["DDRNet"]
alias: "aux_segmentation_head"
params:
attach_index: -2
remove_on_export: true

losses:
- attached_to: segmentation_head
name: CrossEntropyLoss
- attached_to: aux_segmentation_head
name: CrossEntropyLoss
trainer:
preprocessing:
train_image_size:
- &height 128
- &width 128
keep_aspect_ratio: False
normalize:
active: True

batch_size: 2
epochs: &epochs 1
num_workers: 8
validation_interval: 10
num_log_images: 8

callbacks:
- name: ExportOnTrainEnd

optimizer:
name: SGD
params:
lr: 0.01
momentum: 0.9
weight_decay: 0.0005

scheduler:
name: CosineAnnealingLR
params:
T_max: *epochs
42 changes: 42 additions & 0 deletions tests/integration/test_remove_on_export.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from pathlib import Path

import onnxruntime as rt
import pytest
from luxonis_ml.data import LuxonisDataset

from luxonis_train.core import LuxonisModel

ONNX_PATH = Path("tests/integration/_ddrnet_model.onnx")


@pytest.fixture(scope="function", autouse=True)
def clear_files():
yield
ONNX_PATH.unlink(missing_ok=True)


def test_train_only_heads(coco_dataset: LuxonisDataset):
config_file = "tests/configs/ddrnet.yaml"

opts = {"loader.params.dataset_name": coco_dataset.dataset_name}

model = LuxonisModel(config_file, opts)
results = model.test()

name_to_check = "aux_segmentation_head"
is_in_results = any(name_to_check in key for key in results)

model.export(str(ONNX_PATH))

sess = rt.InferenceSession(str(ONNX_PATH))
onnx_output_names = [output.name for output in sess.get_outputs()]
is_in_output_names = any(
name_to_check in name for name in onnx_output_names
)

assert (
is_in_results
), "'aux_segmentation_head' should be in the test results"
assert (
not is_in_output_names
), "'aux_segmentation_head' should not be in the ONNX output names"

0 comments on commit 121c8b9

Please sign in to comment.