Skip to content

Commit

Permalink
Auto-reduce MACE batch size for small training sets
Browse files: browse the repository at this point in the history
  • Loading branch information
svandenhaute committed Nov 1, 2023
1 parent 0401aec commit b700c42
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 1 deletion.
2 changes: 2 additions & 0 deletions psiflow/models/train_mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,8 @@ def timeout_handler(signum, frame):
)
logging.info(f"Atomic energies: {atomic_energies.tolist()}")

args.batch_size = min(len(collections.train), args.batch_size)
print("actual batch size: {}".format(args.batch_size))
train_loader = torch_geometric.dataloader.DataLoader(
dataset=[
data.AtomicData.from_config(
Expand Down
7 changes: 6 additions & 1 deletion tests/test_models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import ast
import copy
import os

import numpy as np
Expand Down Expand Up @@ -225,7 +226,11 @@ def test_mace_init(mace_config, dataset):
assert "1:" in initialized_config["E0s"]
assert "29:" in initialized_config["E0s"]

model = MACEModel(mace_config)
config = copy.deepcopy(mace_config)
config[
"batch_size"
] = 100000 # bigger than ntrain --> should get reduced internally
model = MACEModel(config)
model.seed = 1
model.initialize(dataset[:3])
assert isinstance(model.model_future, DataFuture)
Expand Down

0 comments on commit b700c42

Please sign in to comment.