From b700c42caa45d3bb25e59880f21b8e02c0bc9677 Mon Sep 17 00:00:00 2001
From: svdenhau
Date: Wed, 1 Nov 2023 14:31:46 +0100
Subject: [PATCH] auto-reduce mace batch size for small training sets

---
 psiflow/models/train_mace.py | 2 ++
 tests/test_models.py         | 7 ++++++-
 2 files changed, 8 insertions(+), 1 deletion(-)

diff --git a/psiflow/models/train_mace.py b/psiflow/models/train_mace.py
index d57d4a1..dabc696 100644
--- a/psiflow/models/train_mace.py
+++ b/psiflow/models/train_mace.py
@@ -213,6 +213,8 @@ def timeout_handler(signum, frame):
     )
 
     logging.info(f"Atomic energies: {atomic_energies.tolist()}")
+    args.batch_size = min(len(collections.train), args.batch_size)
+    print("actual batch size: {}".format(args.batch_size))
     train_loader = torch_geometric.dataloader.DataLoader(
         dataset=[
             data.AtomicData.from_config(
diff --git a/tests/test_models.py b/tests/test_models.py
index 5ad52c8..836a1d7 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -1,4 +1,5 @@
 import ast
+import copy
 import os
 
 import numpy as np
@@ -225,7 +226,11 @@ def test_mace_init(mace_config, dataset):
     assert "1:" in initialized_config["E0s"]
     assert "29:" in initialized_config["E0s"]
 
-    model = MACEModel(mace_config)
+    config = copy.deepcopy(mace_config)
+    config[
+        "batch_size"
+    ] = 100000  # bigger than ntrain --> should get reduced internally
+    model = MACEModel(config)
     model.seed = 1
     model.initialize(dataset[:3])
     assert isinstance(model.model_future, DataFuture)
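
Illustration (not part of the applied patch): a minimal, self-contained Python sketch of the clamping behaviour introduced above, where the configured batch size is reduced to the number of training configurations whenever the training set is smaller than the batch size. The names effective_batch_size, configured_batch_size, and n_train are hypothetical and do not appear in psiflow or MACE.

# Hypothetical helper mirroring the patched line
# `args.batch_size = min(len(collections.train), args.batch_size)`.
def effective_batch_size(configured_batch_size: int, n_train: int) -> int:
    # Never request more samples per batch than there are training configurations,
    # so the train loader still yields at least one batch for tiny training sets.
    return min(n_train, configured_batch_size)

if __name__ == "__main__":
    # e.g. a configured batch size of 100000 with only 3 training configurations
    print(effective_batch_size(100000, 3))  # -> 3 (reduced)
    print(effective_batch_size(16, 1000))   # -> 16 (unchanged when data is plentiful)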