Skip to content

Commit

Permalink
allow disabling scaling, in case you already did it outside of predict
Browse files Browse the repository at this point in the history
  • Loading branch information
JacksonBurns committed Jan 16, 2025
1 parent 7b720fa commit 2579d02
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions fastprop/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,17 +161,18 @@ def test_step(self, batch, batch_idx):
self._human_loss(y_hat, batch, "test")
return loss

def predict_step(self, batch: Tuple[torch.Tensor]):
def predict_step(self, batch: Tuple[torch.Tensor], rescale: bool = True):
"""Applies feature scaling and appropriate activation function to a Tensor of descriptors.
Args:
batch (tuple[torch.Tensor]): Unscaled descriptors.
rescale (bool, optional): Apply rescaling according to trained means and vars. Default True.
Returns:
torch.Tensor: Predictions.
"""
descriptors = batch[0]
if self.feature_means is not None and self.feature_vars is not None:
if rescale and self.feature_means is not None and self.feature_vars is not None:
descriptors = standard_scale(descriptors, self.feature_means, self.feature_vars)
with torch.inference_mode():
logits = self.forward(descriptors)
Expand Down

0 comments on commit 2579d02

Please sign in to comment.