Commit 7b9a1ab
fix: Use correct batch size when the dataset is shorter than the actual batch size (#30)

* fix: Use correct batch size when the dataset is shorter than the actual batch size

* build: Update version
lorenzomammana authored Oct 24, 2024
1 parent c3f32d4 commit 7b9a1ab
Showing 4 changed files with 9 additions and 3 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -4,6 +4,12 @@ All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

+## [v0.7.0.dev143]
+
+### Fixed
+
+- Fix wrong embeddings initialization in patchcore when the dataset is smaller than the batch size.
+
## [v0.7.0.dev142]

### Updated
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "anomalib-orobix"
version = "0.7.0.dev142"
version = "0.7.0.dev143"
description = "Orobix anomalib fork"
authors = [
"Intel OpenVINO <[email protected]>",
2 changes: 1 addition & 1 deletion src/anomalib/__init__.py
@@ -4,6 +4,6 @@
# SPDX-License-Identifier: Apache-2.0

anomalib_version = "0.7.0"
custom_orobix_version = "1.4.2"
custom_orobix_version = "1.4.3"

__version__ = f"{anomalib_version}.dev{custom_orobix_version.replace('.', '')}"
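For reference, the same computation evaluated standalone shows how the bumped Orobix version maps onto the published package version:

anomalib_version = "0.7.0"
custom_orobix_version = "1.4.3"

# Dots are stripped from the Orobix version and appended as a dev suffix,
# which is why bumping 1.4.2 -> 1.4.3 yields the v0.7.0.dev143 changelog entry.
__version__ = f"{anomalib_version}.dev{custom_orobix_version.replace('.', '')}"
print(__version__)  # 0.7.0.dev143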
2 changes: 1 addition & 1 deletion src/anomalib/models/patchcore/lightning_model.py
@@ -121,7 +121,7 @@ def training_step(self, batch: dict[str, str | Tensor], *args, **kwargs) -> None
        # Initialize the embeddings tensor with the estimated number of batches
        self.embeddings = torch.zeros(
            (
-                (embedding.shape[0] // self.trainer.train_dataloader.batch_size)
+                (embedding.shape[0] // batch["image"].shape[0])
                * len(self.trainer.train_dataloader.dataset)
                * self.trainer.max_epochs,
                *embedding.shape[1:],
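To make the under-allocation concrete, here is a minimal standalone sketch of the size estimate before and after the fix. The numbers (784 patch embeddings per image, a 4-image dataset, a configured batch size of 32, one epoch) are illustrative assumptions, not values taken from the repository:

# embedding.shape[0] grows with the number of images actually in the batch,
# so dividing by the DataLoader's configured batch_size only recovers the
# patches-per-image count when the batch is full.
patches_per_image = 784      # assumed: e.g. a flattened 28x28 feature map
dataset_len = 4              # dataset shorter than the configured batch size
configured_batch_size = 32   # what trainer.train_dataloader.batch_size reports
actual_batch_size = 4        # what batch["image"].shape[0] reports
max_epochs = 1

embedding_rows = patches_per_image * actual_batch_size  # 3136 rows produced

# Before the fix: integer division by 32 underestimates patches per image,
# so the preallocated buffer is far smaller than the embeddings produced.
rows_before = (embedding_rows // configured_batch_size) * dataset_len * max_epochs
print(rows_before)  # 392, but 3136 rows are needed

# After the fix: dividing by the actual batch size recovers 784 patches per
# image, and the buffer matches exactly what training will write into it.
rows_after = (embedding_rows // actual_batch_size) * dataset_len * max_epochs
print(rows_after)  # 3136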
