feat: Improve vram usage for patchcore models (#29)
* refactor: Avoid using larger projection

* refactor: Preallocate embeddings to avoid huge memory spikes on stack

* refactor: Improve memory usage and empty cache more frequently

* refactor: Load memory bank on gpu only after cache is freed

* build: Upgrade version, update docstring

* refactor: Include max_epochs in preallocation

* build: Upgrade version

* fix: Fix wrong empty batch size initialization and filling

lorenzomammana authored Oct 24, 2024
1 parent d8448ec commit c3f32d4
Showing 6 changed files with 79 additions and 14 deletions.
10 changes: 10 additions & 0 deletions CHANGELOG.md
@@ -4,6 +4,16 @@ All notable changes to this project will be documented in this file.
 
 The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
 
+## [v0.7.0.dev142]
+
+### Updated
+
+- Improve VRAM usage for patchcore trainings.
+
+### Fixed
+
+- Avoid projecting features to a larger space when Johnson-Lindenstrauss lemma suggests it.
+
 ## [v0.7.0.dev141]
 
 ### Updated
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "anomalib-orobix"
-version = "0.7.0.dev141"
+version = "0.7.0.dev142"
 description = "Orobix anomalib fork"
 authors = [
     "Intel OpenVINO <[email protected]>",
2 changes: 1 addition & 1 deletion src/anomalib/__init__.py
@@ -4,6 +4,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
 anomalib_version = "0.7.0"
-custom_orobix_version = "1.4.1"
+custom_orobix_version = "1.4.2"
 
 __version__ = f"{anomalib_version}.dev{custom_orobix_version.replace('.', '')}"
19 changes: 13 additions & 6 deletions (SparseRandomProjection)
@@ -39,6 +39,7 @@ def __init__(self, eps: float = 0.1, random_state: Optional[int] = None) -> None
         self.sparse_random_matrix: Tensor
         self.eps = eps
         self.random_state = random_state
+        self._use_identity = False
 
     def _sparse_random_matrix(self, n_features: int):
         """Random sparse matrix. Based on https://web.stanford.edu/~hastie/Papers/Ping/KDD06_rp.pdf.
@@ -111,13 +112,16 @@ def fit(self, embedding: Tensor) -> SparseRandomProjection:
         device = embedding.device
 
         self.n_components = self.johnson_lindenstrauss_min_dim(n_samples=n_samples, eps=self.eps)
-        # TODO: What if n_components > n_features?
 
-        # Generate projection matrix
-        # torch can't multiply directly on sparse matrix and moving sparse matrix to cuda throws error
-        # (Could not run 'aten::empty_strided' with arguments from the 'SparseCsrCUDA' backend)
-        # hence sparse matrix is stored as a dense matrix on the device
-        self.sparse_random_matrix = self._sparse_random_matrix(n_features=n_features).to(device)
+        if self.n_components < n_features:
+            # Generate projection matrix
+            # torch can't multiply directly on sparse matrix and moving sparse matrix to cuda throws error
+            # (Could not run 'aten::empty_strided' with arguments from the 'SparseCsrCUDA' backend)
+            # hence sparse matrix is stored as a dense matrix on the device
+            self.sparse_random_matrix = self._sparse_random_matrix(n_features=n_features).to(device)
+        else:
+            self.sparse_random_matrix = torch.tensor([])
+            self._use_identity = True
 
         return self
 
@@ -138,6 +142,9 @@ def transform(self, embedding: Tensor) -> Tensor:
         if self.sparse_random_matrix is None:
             raise NotFittedError("`fit()` has not been called on SparseRandomProjection yet.")
 
+        if self._use_identity:
+            return embedding
+
         if embedding.dtype == torch.float32:
             projected_embedding = embedding @ self.sparse_random_matrix.T.float()
         else:
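For reference, the guard introduced above can be reproduced in isolation. The snippet below is a minimal sketch, not the fork's actual class: it assumes the standard Johnson-Lindenstrauss bound and substitutes a dense Gaussian matrix for the sparse random matrix, purely to show when the identity fallback applies; `maybe_project` is a hypothetical name.

```python
import math

import torch


def johnson_lindenstrauss_min_dim(n_samples: int, eps: float = 0.1) -> int:
    """Classic JL bound: minimum target dimension preserving pairwise distances within (1 +/- eps)."""
    denominator = (eps**2 / 2) - (eps**3 / 3)
    return int(4 * math.log(n_samples) / denominator)


def maybe_project(embedding: torch.Tensor, eps: float = 0.1) -> torch.Tensor:
    """Project only when the JL target dimension is smaller than the input dimension."""
    n_samples, n_features = embedding.shape
    n_components = johnson_lindenstrauss_min_dim(n_samples, eps)
    if n_components >= n_features:
        # Projecting would not reduce (or would even increase) dimensionality:
        # behave as an identity transform, like the `_use_identity` flag above.
        return embedding
    # Dense Gaussian stand-in for the sparse random matrix used by the real model.
    projection = torch.randn(n_components, n_features, device=embedding.device) / math.sqrt(n_components)
    return embedding @ projection.T


# Example: with few samples in a high-dimensional space the JL dimension exceeds
# n_features, so the embedding is returned untouched instead of being blown up.
emb = torch.randn(1000, 384)
print(johnson_lindenstrauss_min_dim(1000), maybe_project(emb).shape)
```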
38 changes: 33 additions & 5 deletions src/anomalib/models/patchcore/lightning_model.py
@@ -79,9 +79,10 @@ def __init__(
             anomaly_score_from_max_heatmap=anomaly_score_from_max_heatmap,
         )
         self.coreset_sampling_ratio = coreset_sampling_ratio
-        self.embeddings: list[Tensor] = []
+        self.embeddings: Tensor = torch.tensor([])
         self.automatic_optimization = False
         self.coreset_sampler = coreset_sampler
+        self.counter = 0
 
     def configure_optimizers(self) -> None:
         """Configure optimizers.
@@ -92,7 +93,8 @@ def configure_optimizers(self) -> None:
         return None
 
     def on_train_epoch_start(self) -> None:
-        self.embeddings = []
+        self.embeddings = torch.tensor([])
+        self.counter = 0
         return super().on_train_epoch_start()
 
     def training_step(self, batch: dict[str, str | Tensor], *args, **kwargs) -> None:
@@ -113,7 +115,29 @@ def training_step(self, batch: dict[str, str | Tensor], *args, **kwargs) -> None
         # store the training set embedding. We manually append these
         # values mainly due to the new order of hooks introduced after PL v1.4.0
         # https://github.com/PyTorchLightning/pytorch-lightning/pull/7357
-        self.embeddings.append(embedding)
+
+        if len(self.embeddings) == 0:
+            if not self.trainer.sanity_checking:
+                # Initialize the embeddings tensor with the estimated number of batches
+                self.embeddings = torch.zeros(
+                    (
+                        (embedding.shape[0] // self.trainer.train_dataloader.batch_size)
+                        * len(self.trainer.train_dataloader.dataset)
+                        * self.trainer.max_epochs,
+                        *embedding.shape[1:],
+                    ),
+                    device=self.device,
+                    dtype=embedding.dtype,
+                )
+            else:
+                self.embeddings = self.embeddings.to(device=self.device, dtype=embedding.dtype)
+
+        if not self.trainer.sanity_checking:
+            self.embeddings[self.counter : self.counter + embedding.shape[0]] = embedding
+        else:
+            self.embeddings = torch.cat((self.embeddings, embedding))
+
+        self.counter += embedding.shape[0]
         zero_loss = torch.tensor(0.0, requires_grad=True, device=self.device)
         return {"loss": zero_loss}
 
@@ -133,10 +157,14 @@ def on_validation_start(self) -> None:
         self.model.eval()
 
         logger.info("Aggregating the embedding extracted from the training set.")
-        embeddings = torch.vstack(self.embeddings)
 
         logger.info("Applying core-set subsampling to get the embedding.")
-        self.model.subsample_embedding(embeddings, self.coreset_sampling_ratio, mode=self.coreset_sampler)
+
+        self.model.subsample_embedding(self.embeddings, self.coreset_sampling_ratio, mode=self.coreset_sampler)
+        self.embeddings = torch.tensor([])
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
 
     def validation_step(self, batch: dict[str, str | Tensor], *args, **kwargs) -> STEP_OUTPUT:
         """Get batch of anomaly maps from input image batch.
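The new `training_step` sizes one flat buffer up front (rows per image, times dataset length, times `max_epochs`) and writes each batch into its slice via a running counter, so no list of tensors has to be stacked, and briefly duplicated, at validation time. Below is a simplified sketch of that sizing and filling pattern; `preallocate_buffer` and the toy shapes are illustrative assumptions, not the module's API.

```python
import torch


def preallocate_buffer(first_batch: torch.Tensor, batch_size: int, dataset_len: int, max_epochs: int) -> torch.Tensor:
    """Allocate one flat buffer for every patch embedding the training run will produce.

    `first_batch` is assumed to have shape (batch_size * patches_per_image, feature_dim).
    """
    patches_per_image = first_batch.shape[0] // batch_size
    total_rows = patches_per_image * dataset_len * max_epochs
    return torch.zeros((total_rows, *first_batch.shape[1:]), device=first_batch.device, dtype=first_batch.dtype)


# Filling pattern used in training_step: write each batch into its slice and advance
# a counter, instead of appending tensors to a list and stacking them afterwards.
buffer, counter = None, 0
for _ in range(3):  # stand-in for the training loop
    embedding = torch.randn(8 * 196, 384)  # 8 images, 196 patches each, 384-dim features
    if buffer is None:
        buffer = preallocate_buffer(embedding, batch_size=8, dataset_len=24, max_epochs=1)
    buffer[counter : counter + embedding.shape[0]] = embedding
    counter += embedding.shape[0]
```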
22 changes: 21 additions & 1 deletion src/anomalib/models/patchcore/torch_model.py
@@ -230,15 +230,35 @@ def subsample_embedding_anomalib(self, embedding: torch.Tensor, sampling_ratio:
         """
         log.info("Subsampling embedding with anomalib coreset sampling")
         self.projection_model.fit(embedding)
-        compressed_embedding = self.projection_model.transform(embedding)
+        original_device = embedding.device
+
+        if not self.projection_model._use_identity:
+            compressed_embedding = self.projection_model.transform(embedding)
+
+            if torch.cuda.is_available() and embedding.device != "cpu":
+                original_device = embedding.device
+                # Unload the embedding from the GPU to free up memory
+                embedding = embedding.to("cpu")
+                torch.cuda.empty_cache()
+        else:
+            compressed_embedding = embedding
 
         # Coreset Subsampling
         sampler = KCenterGreedy(sampling_ratio=sampling_ratio)
         coreset_indices = sampler.sample_coreset(compressed_embedding)
         if self.compress_memory_bank:
             self.memory_bank = compressed_embedding[coreset_indices]
+            del compressed_embedding
         else:
+            del compressed_embedding
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
             self.memory_bank = embedding[coreset_indices]
+            self.memory_bank = self.memory_bank.to(original_device)
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
 
     def subsample_embedding_amazon(self, embedding: torch.Tensor, sampling_ratio: float) -> None:
         """Subsample embedding based on coreset sampling and store to memory.
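In `subsample_embedding_anomalib`, the compressed copy is deleted (and the full embedding moved to the CPU) before the coreset rows are gathered, and the memory bank is loaded back onto the GPU only after the cache has been emptied. The following condensed sketch shows that ordering; `build_memory_bank` and the `.cpu()` call on the indices are assumptions made to keep the example self-contained, not the fork's actual method.

```python
import torch


def build_memory_bank(embedding: torch.Tensor, coreset_indices: torch.Tensor) -> torch.Tensor:
    """Gather the coreset rows while keeping peak VRAM low."""
    original_device = embedding.device

    # Move the full embedding to host memory so it no longer occupies VRAM
    # while the GPU caching allocator releases its blocks.
    if torch.cuda.is_available() and embedding.device.type != "cpu":
        embedding = embedding.cpu()
        torch.cuda.empty_cache()

    # Index on CPU: only the selected rows are materialized.
    memory_bank = embedding[coreset_indices.cpu()]

    # Load the (much smaller) memory bank back on the GPU only after the cache is freed.
    memory_bank = memory_bank.to(original_device)
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return memory_bank


# Example with a toy embedding and a handful of coreset indices.
emb = torch.randn(10_000, 384)
idx = torch.randint(0, 10_000, (100,))
print(build_memory_bank(emb, idx).shape)  # torch.Size([100, 384])
```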
