Skip to content

Commit

Permalink
Updated weights loading for DDRNet (#148)
Browse files Browse the repository at this point in the history
  • Loading branch information
klemen1999 authored Jan 5, 2025
1 parent b315fd9 commit 6076408
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
2 changes: 1 addition & 1 deletion luxonis_train/nodes/backbones/ddrnet/variants.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def get_variant(variant: Literal["23-slim", "23"]) -> DDRNetVariant:
"23-slim": DDRNetVariant(
channels=32,
highres_channels=64,
weights_path="https://github.com/luxonis/luxonis-train/releases/download/v0.1.0-beta/ddrnet_23slim_coco.ckpt",
weights_path="https://github.com/luxonis/luxonis-train/releases/download/v0.2.1-beta/ddrnet_23slim_coco.ckpt",
),
"23": DDRNetVariant(
channels=64,
Expand Down
19 changes: 16 additions & 3 deletions luxonis_train/nodes/heads/ddrnet_segmentation_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def __init__(
(self.in_height, self.in_width), (model_in_h, model_in_w)
)
self.scale_factor = scale_factor

if (
inter_mode == "pixel_shuffle"
and inter_channels % (scale_factor**2) != 0
Expand Down Expand Up @@ -95,9 +94,23 @@ def __init__(
if inter_mode == "pixel_shuffle"
else nn.Upsample(scale_factor=scale_factor, mode=inter_mode)
)

if download_weights:
weights_path = "https://github.com/luxonis/luxonis-train/releases/download/v0.1.0-beta/ddrnet_head_coco.ckpt"
self.load_checkpoint(weights_path, strict=False)
weights_path = self.get_variant_weights()
if weights_path:
self.load_checkpoint(path=weights_path, strict=False)
else:
logger.warning(
f"No checkpoint available for {self.name}, skipping."
)

def get_variant_weights(self) -> str | None:
if self.in_channels == 128: # light predefined model
return "https://github.com/luxonis/luxonis-train/releases/download/v0.2.1-beta/ddrnet_head_23slim_coco.ckpt"
elif self.in_channels == 256: # heavy predefined model
return "https://github.com/luxonis/luxonis-train/releases/download/v0.2.1-beta/ddrnet_head_23_coco.ckpt"
else:
return None

def forward(self, inputs: Tensor) -> Tensor:
x: Tensor = self.relu(self.bn1(inputs))
Expand Down

0 comments on commit 6076408

Please sign in to comment.