Skip to content

Commit

Permalink
refactor: Improve code
Browse files Browse the repository at this point in the history
  • Loading branch information
lorenzomammana committed Jul 8, 2024
1 parent 9323a11 commit 33fa210
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion quadra/tasks/anomaly.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ def test(self) -> None:
batch_labels = batch_item["label"]
image_labels.extend(batch_labels.tolist())
image_paths.extend(batch_item["image_path"])
batch_images = batch_images.to(self.device).to(self.deployment_model.model_dtype)
batch_images = batch_images.to(device=self.device, dtype=self.deployment_model.model_dtype)
if self.model_data.get("anomaly_method") == "efficientad":
model_output = self.deployment_model(batch_images, None)
else:
Expand Down
2 changes: 1 addition & 1 deletion quadra/tasks/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -1169,7 +1169,7 @@ def test(self) -> None:
with torch.set_grad_enabled(self.gradcam):
for batch_item in tqdm(test_dataloader):
im, target = batch_item
im = im.to(self.device).to(self.deployment_model.model_dtype).detach()
im = im.to(device=self.device, dtype=self.deployment_model.model_dtype).detach()

if self.gradcam:
# When gradcam is used we need to remove gradients
Expand Down
2 changes: 1 addition & 1 deletion quadra/tasks/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def test(self) -> None:
image_list, mask_list, mask_pred_list, label_list = [], [], [], []
for batch in dataloader:
images, masks, labels = batch
images = images.to(self.device).to(self.deployment_model.model_dtype)
images = images.to(device=self.device, dtype=self.deployment_model.model_dtype)
if len(masks.shape) == 3: # BxHxW -> Bx1xHxW
masks = masks.unsqueeze(1)
with torch.no_grad():
Expand Down

0 comments on commit 33fa210

Please sign in to comment.