From 2579d02dedf39d2b42ea866ef849f547d3684600 Mon Sep 17 00:00:00 2001 From: JacksonBurns Date: Thu, 16 Jan 2025 16:01:53 -0500 Subject: [PATCH] allow disabling scaling, in case you already did it outside of predict --- fastprop/model.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/fastprop/model.py b/fastprop/model.py index c5c2be1..eaa0bf9 100644 --- a/fastprop/model.py +++ b/fastprop/model.py @@ -161,17 +161,18 @@ def test_step(self, batch, batch_idx): self._human_loss(y_hat, batch, "test") return loss - def predict_step(self, batch: Tuple[torch.Tensor]): + def predict_step(self, batch: Tuple[torch.Tensor], rescale: bool = True): """Applies feature scaling and appropriate activation function to a Tensor of descriptors. Args: batch (tuple[torch.Tensor]): Unscaled descriptors. + rescale (bool, optional): Apply rescaling according to trained means and vars. Default True. Returns: torch.Tensor: Predictions. """ descriptors = batch[0] - if self.feature_means is not None and self.feature_vars is not None: + if rescale and self.feature_means is not None and self.feature_vars is not None: descriptors = standard_scale(descriptors, self.feature_means, self.feature_vars) with torch.inference_mode(): logits = self.forward(descriptors)