From 607640829d946f91a96de195967a090ec838d6dc Mon Sep 17 00:00:00 2001 From: KlemenSkrlj <47853619+klemen1999@users.noreply.github.com> Date: Sun, 5 Jan 2025 19:27:48 +0100 Subject: [PATCH] Updated weights loading for DDRNet (#148) --- .../nodes/backbones/ddrnet/variants.py | 2 +- .../nodes/heads/ddrnet_segmentation_head.py | 19 ++++++++++++++++--- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/luxonis_train/nodes/backbones/ddrnet/variants.py b/luxonis_train/nodes/backbones/ddrnet/variants.py index 15ba986d..97d22fde 100644 --- a/luxonis_train/nodes/backbones/ddrnet/variants.py +++ b/luxonis_train/nodes/backbones/ddrnet/variants.py @@ -14,7 +14,7 @@ def get_variant(variant: Literal["23-slim", "23"]) -> DDRNetVariant: "23-slim": DDRNetVariant( channels=32, highres_channels=64, - weights_path="https://github.com/luxonis/luxonis-train/releases/download/v0.1.0-beta/ddrnet_23slim_coco.ckpt", + weights_path="https://github.com/luxonis/luxonis-train/releases/download/v0.2.1-beta/ddrnet_23slim_coco.ckpt", ), "23": DDRNetVariant( channels=64, diff --git a/luxonis_train/nodes/heads/ddrnet_segmentation_head.py b/luxonis_train/nodes/heads/ddrnet_segmentation_head.py index 001c52ab..2b313ab6 100644 --- a/luxonis_train/nodes/heads/ddrnet_segmentation_head.py +++ b/luxonis_train/nodes/heads/ddrnet_segmentation_head.py @@ -61,7 +61,6 @@ def __init__( (self.in_height, self.in_width), (model_in_h, model_in_w) ) self.scale_factor = scale_factor - if ( inter_mode == "pixel_shuffle" and inter_channels % (scale_factor**2) != 0 @@ -95,9 +94,23 @@ def __init__( if inter_mode == "pixel_shuffle" else nn.Upsample(scale_factor=scale_factor, mode=inter_mode) ) + if download_weights: - weights_path = "https://github.com/luxonis/luxonis-train/releases/download/v0.1.0-beta/ddrnet_head_coco.ckpt" - self.load_checkpoint(weights_path, strict=False) + weights_path = self.get_variant_weights() + if weights_path: + self.load_checkpoint(path=weights_path, strict=False) + else: + logger.warning( + f"No checkpoint available for {self.name}, skipping." + ) + + def get_variant_weights(self) -> str | None: + if self.in_channels == 128: # light predefined model + return "https://github.com/luxonis/luxonis-train/releases/download/v0.2.1-beta/ddrnet_head_23slim_coco.ckpt" + elif self.in_channels == 256: # heavy predefined model + return "https://github.com/luxonis/luxonis-train/releases/download/v0.2.1-beta/ddrnet_head_23_coco.ckpt" + else: + return None def forward(self, inputs: Tensor) -> Tensor: x: Tensor = self.relu(self.bn1(inputs))