Skip to content

Commit

Permalink
Fix/load pretrained weights (#25)
Browse files Browse the repository at this point in the history
* fix: Fix pretrained weights loading for patchcore, remove deprecated feature extractor

* refactor: Remove deprecated class usage

* refactor: Minor code improvements, add comments

* build: Update version, update changelog

* refactor: Update deprecated class usage
  • Loading branch information
lorenzomammana authored Jun 4, 2024
1 parent 9f697b2 commit 02953d3
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 26 deletions.
10 changes: 10 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,16 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

## [v0.7.0+obx.1.3.3]

### Fixed

- Fix pretrained weights not being loaded properly in PatchCore

### Updated

- Use `TimmFeatureExtractor` instead of the deprecated `FeatureExtractor`

## [v0.7.0+obx.1.3.2]

### Fixed
Expand Down
2 changes: 1 addition & 1 deletion src/anomalib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@
# SPDX-License-Identifier: Apache-2.0

anomalib_version = "0.7.0"
custom_orobix_version = "1.3.2"
custom_orobix_version = "1.3.3"

__version__ = f"{anomalib_version}+obx.{custom_orobix_version}"
10 changes: 6 additions & 4 deletions src/anomalib/models/efficient_ad/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,13 +124,15 @@ def __init__(
def prepare_pretrained_model(self) -> None:
if self.model.pretrained_teacher_type == "nelson":
if self.model.model_size == EfficientAdModelSize.S:
teacher_path = self.pretrained_models_dir / NELSON_TEACHER_S.name
download_single_file(self.pretrained_models_dir, NELSON_TEACHER_S)
model_type = NELSON_TEACHER_S
else:
teacher_path = self.pretrained_models_dir / NELSON_TEACHER_M.name
download_single_file(self.pretrained_models_dir, NELSON_TEACHER_M)
model_type = NELSON_TEACHER_M
teacher_path = self.pretrained_models_dir / model_type.name
# Attempt to download if file does not exist
download_single_file(self.pretrained_models_dir, model_type)
else:
download_and_extract(self.pretrained_models_dir, WEIGHTS_DOWNLOAD_INFO)
# Why is it nelson also here?
teacher_path = (
self.pretrained_models_dir / "efficientad_pretrained_weights" / f"nelson_teacher_{self.model_size}.pth"
)
Expand Down
3 changes: 2 additions & 1 deletion src/anomalib/models/padim/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from anomalib.models.components import FeatureExtractor, MultiVariateGaussian
from anomalib.models.components.feature_extractors import dryrun_find_featuremap_dims
from anomalib.models.components.feature_extractors.timm import TimmFeatureExtractor
from anomalib.models.padim.anomaly_map import AnomalyMapGenerator
from anomalib.pre_processing import Tiler

Expand Down Expand Up @@ -87,7 +88,7 @@ def __init__(
self.layers = layers
self.backbone = backbone

self.feature_extractor = FeatureExtractor(
self.feature_extractor = TimmFeatureExtractor(
backbone=self.backbone, layers=layers, pre_trained=pre_trained, pretrained_weights=pretrained_weights
)
self.n_features_original, self.n_patches = _deduce_dims(self.feature_extractor, input_size, self.layers)
Expand Down
24 changes: 4 additions & 20 deletions src/anomalib/models/patchcore/torch_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@

from anomalib.models.components import (
DynamicBufferModule,
FeatureExtractor,
KCenterGreedy,
)
from anomalib.models.components.dimensionality_reduction.random_projection import SparseRandomProjection
from anomalib.models.components.feature_extractors.timm import TimmFeatureExtractor
from anomalib.models.components.sampling.amazon_k_center_greedy import ApproximateGreedyCoresetSampler
from anomalib.models.patchcore.anomaly_map import AnomalyMapGenerator
from anomalib.pre_processing import Tiler
Expand Down Expand Up @@ -72,25 +72,9 @@ def __init__(
self.num_neighbors: Tensor
self.score_computation = score_computation

# TODO: Hardcoded stuff I think for ssl?
if pretrained_weights is not None and not isinstance(self.backbone, str):
log.info("Loading pretrained weights")

with open(pretrained_weights, "rb") as f:
weights = torch.load(f)

new_state_dict = OrderedDict()

for key, value in weights["state_dict"].items():
if "student" in key or "teacher" in key:
continue

new_key = key.replace("model.features_extractor.", "")
new_state_dict[new_key] = value

self.backbone.load_state_dict(new_state_dict, strict=False)

self.feature_extractor = FeatureExtractor(backbone=self.backbone, layers=self.layers, pre_trained=pre_trained)
self.feature_extractor = TimmFeatureExtractor(
backbone=self.backbone, layers=self.layers, pre_trained=pre_trained, pretrained_weights=pretrained_weights
)
self.feature_pooler = torch.nn.AvgPool2d(3, 1, 1)
self.anomaly_map_generator = AnomalyMapGenerator(input_size=input_size)

Expand Down

0 comments on commit 02953d3

Please sign in to comment.