diff --git a/CHANGELOG.md b/CHANGELOG.md
index e7014a36..6b937d84 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,6 +2,13 @@
 # Changelog
 All notable changes to this project will be documented in this file.
 
+### [2.1.9]
+
+#### Updated
+
+- Update anomalib to v0.7.0+obx.1.3.3
+- Update network builders to support loading model checkpoints from disk
+
 ### [2.1.8]
 
 #### Added
diff --git a/poetry.lock b/poetry.lock
index a519ba4d..fa4da03e 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -187,7 +187,7 @@ test = ["flake8 (==3.7.9)", "mock (==2.0.0)", "pylint (==1.9.3)"]
 
 [[package]]
 name = "anomalib"
-version = "0.7.0+obx.1.3.2"
+version = "0.7.0+obx.1.3.3"
 description = "anomalib - Anomaly Detection Library"
 optional = false
 python-versions = ">=3.7"
@@ -214,8 +214,8 @@ openvino = ["defusedxml (==0.7.1)", "networkx (>=2.5,<3.0)", "nncf (>=2.1.0)", "
 [package.source]
 type = "git"
 url = "https://github.com/orobix/anomalib.git"
-reference = "v0.7.0+obx.1.3.2"
-resolved_reference = "9f697b2ab32138515fade4bedbf166b4f6901170"
+reference = "v0.7.0+obx.1.3.3"
+resolved_reference = "02953d367168659ad9324ca03a9fa9c63aa69dbe"
 
 [[package]]
 name = "antlr4-python3-runtime"
@@ -6937,4 +6937,4 @@ onnx = ["onnx", "onnxconverter-common", "onnxruntime_gpu", "onnxsim"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.9,<3.11"
-content-hash = "68361649693152e3d71b1b85bb326dfa4984c9b170c1bee846937588b023fd70"
+content-hash = "d4b357468bfe398840abf8c9011767d00ddc950f14176894106e748a5fdcd649"
diff --git a/pyproject.toml b/pyproject.toml
index c93758c6..c5ec58bc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "quadra"
-version = "2.1.8"
+version = "2.1.9"
 description = "Deep Learning experiment orchestration library"
 authors = [
     "Federico Belotti ",
@@ -77,7 +77,7 @@ h5py = "~3.8"
 timm = "0.9.12"
 # Right now only this ref supports timm 0.9.12
 segmentation_models_pytorch = { git = "https://github.com/qubvel/segmentation_models.pytorch", rev = "7b381f899ed472a477a89d381689caf535b5d0a6" }
-anomalib = { git = "https://github.com/orobix/anomalib.git", tag = "v0.7.0+obx.1.3.2" }
+anomalib = { git = "https://github.com/orobix/anomalib.git", tag = "v0.7.0+obx.1.3.3" }
 xxhash = "~3.2"
 torchinfo = "~1.8"
 typing_extensions = { version = "4.11.0", python = "<3.10" }
diff --git a/quadra/__init__.py b/quadra/__init__.py
index 8d00c9d9..cc92e31d 100644
--- a/quadra/__init__.py
+++ b/quadra/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "2.1.8"
+__version__ = "2.1.9"
 
 
 def get_version():
diff --git a/quadra/models/classification/backbones.py b/quadra/models/classification/backbones.py
index f3651292..3bf48d95 100644
--- a/quadra/models/classification/backbones.py
+++ b/quadra/models/classification/backbones.py
@@ -4,10 +4,14 @@
 
 import timm
 import torch
+from timm.models.helpers import load_checkpoint
 from torch import nn
 from torchvision import models
 
 from quadra.models.classification.base import BaseNetworkBuilder
+from quadra.utils.logger import get_logger
+
+log = get_logger(__name__)
 
 
 class TorchHubNetworkBuilder(BaseNetworkBuilder):
@@ -22,6 +26,7 @@ class TorchHubNetworkBuilder(BaseNetworkBuilder):
         freeze: Whether to freeze the feature extractor. Defaults to True.
         hyperspherical: Whether to map features to an hypersphere. Defaults to False.
         flatten_features: Whether to flatten the features before the pre_classifier. Defaults to True.
+        checkpoint_path: Path to a checkpoint to load after the model is initialized. Defaults to None.
         **torch_hub_kwargs: Additional arguments to pass to torch.hub.load
     """
 
@@ -35,12 +40,17 @@ def __init__(
         freeze: bool = True,
         hyperspherical: bool = False,
         flatten_features: bool = True,
+        checkpoint_path: str | None = None,
         **torch_hub_kwargs: Any,
     ):
         self.pretrained = pretrained
         features_extractor = torch.hub.load(
             repo_or_dir=repo_or_dir, model=model_name, pretrained=self.pretrained, **torch_hub_kwargs
         )
+        if checkpoint_path:
+            log.info("Loading checkpoint from %s", checkpoint_path)
+            load_checkpoint(model=features_extractor, checkpoint_path=checkpoint_path)
+
         super().__init__(
             features_extractor=features_extractor,
             pre_classifier=pre_classifier,
@@ -62,6 +72,7 @@ class TorchVisionNetworkBuilder(BaseNetworkBuilder):
         freeze: Whether to freeze the feature extractor. Defaults to True.
         hyperspherical: Whether to map features to an hypersphere. Defaults to False.
         flatten_features: Whether to flatten the features before the pre_classifier. Defaults to True.
+        checkpoint_path: Path to a checkpoint to load after the model is initialized. Defaults to None.
         **torchvision_kwargs: Additional arguments to pass to the model function.
     """
 
@@ -74,11 +85,16 @@ def __init__(
         freeze: bool = True,
         hyperspherical: bool = False,
         flatten_features: bool = True,
+        checkpoint_path: str | None = None,
         **torchvision_kwargs: Any,
     ):
         self.pretrained = pretrained
         model_function = models.__dict__[model_name]
         features_extractor = model_function(pretrained=self.pretrained, progress=True, **torchvision_kwargs)
+        if checkpoint_path:
+            log.info("Loading checkpoint from %s", checkpoint_path)
+            load_checkpoint(model=features_extractor, checkpoint_path=checkpoint_path)
+
         # Remove classifier
         features_extractor.classifier = nn.Identity()
         super().__init__(
@@ -102,6 +118,7 @@ class TimmNetworkBuilder(BaseNetworkBuilder):
         freeze: Whether to freeze the feature extractor. Defaults to True.
         hyperspherical: Whether to map features to an hypersphere. Defaults to False.
         flatten_features: Whether to flatten the features before the pre_classifier. Defaults to True.
+        checkpoint_path: Path to a checkpoint to load after the model is initialized. Defaults to None.
         **timm_kwargs: Additional arguments to pass to timm.create_model
     """
 
@@ -114,10 +131,13 @@ def __init__(
         freeze: bool = True,
         hyperspherical: bool = False,
         flatten_features: bool = True,
+        checkpoint_path: str | None = None,
         **timm_kwargs: Any,
     ):
         self.pretrained = pretrained
-        features_extractor = timm.create_model(model_name, pretrained=self.pretrained, num_classes=0, **timm_kwargs)
+        features_extractor = timm.create_model(
+            model_name, pretrained=self.pretrained, num_classes=0, checkpoint_path=checkpoint_path, **timm_kwargs
+        )
 
         super().__init__(
             features_extractor=features_extractor,
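A minimal usage sketch of the new `checkpoint_path` argument, assuming the backbone name and checkpoint file below are placeholders and that the remaining constructor arguments keep their defaults:

```python
from quadra.models.classification.backbones import TimmNetworkBuilder

# Placeholder checkpoint path: any state dict compatible with the chosen timm backbone.
backbone = TimmNetworkBuilder(
    model_name="resnet18",
    pretrained=False,
    checkpoint_path="/path/to/finetuned_resnet18.pth",
)
```

`TimmNetworkBuilder` forwards the path straight to `timm.create_model(checkpoint_path=...)`, while `TorchHubNetworkBuilder` and `TorchVisionNetworkBuilder` load it with timm's `load_checkpoint` right after the backbone is instantiated, as shown in the hunks above.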