update push option and saver (#101)
* chore: update .gitignore

* update build.py: check that model_options matches model.model_options

* update deeptb.py: ensure the push option is disabled for the mixed model

* update nnsk.py: do not fill the push option from the checkpoint when it is None in jdata but set in the checkpoint.

* update saver.py:
1. add symlinks to the latest checkpoint (saved per iteration) and the best checkpoint (saved per epoch).
2. save push checkpoints with the current rs and w indicated in the filename.

* update argcheck.py: add max_ckpt to control the number of checkpoints to keep; default is 4.

* update config_sk.py and config_skenv.py

* test: update test_sktb.py

* update saver.py and test_sktb.py
QG-phy authored Mar 26, 2024
1 parent 1caa479 commit 16f9e20
Showing 9 changed files with 75 additions and 32 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -6,6 +6,7 @@ dptb/tests/**/*.pth
dptb/tests/**/*.npy
dptb/tests/**/*.traj
dptb/tests/**/out*/*
examples/_*
*.dat
*.vasp
*log*
5 changes: 4 additions & 1 deletion dptb/nn/build.py
@@ -133,6 +133,9 @@ def build_model(run_options, model_options: dict={}, common_options: dict={}, st
model = NNSK.from_reference(checkpoint, **model_options["nnsk"], **common_options)
if init_mixed:
# the mix model can be initialized with a mixed reference model or an nnsk model.
model = MIX.from_reference(checkpoint, **model_options, **common_options)
model = MIX.from_reference(checkpoint, **model_options, **common_options)

if model.model_options != model_options:
# log.error("The model options are not consistent with the checkpoint, using the one in the checkpoint.")
raise ValueError("The model options are not consistent with the checkpoint, using the one in the checkpoint.")
return model
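The new check refuses to return a model whose effective options differ from what the checkpoint was trained with. A standalone sketch of that guard, assuming a model object that exposes a model_options dict (the names below are illustrative, not the repository's API):

    def assert_options_match(model, model_options: dict) -> None:
        # dict equality is recursive, so nested blocks such as "nnsk" are compared too
        if model.model_options != model_options:
            raise ValueError("The model options are not consistent with the checkpoint.")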
6 changes: 6 additions & 0 deletions dptb/nn/deeptb.py
@@ -10,7 +10,9 @@
from dptb.nn.nnsk import NNSK
from e3nn.o3 import Linear
from dptb.nn.rescale import E3PerSpeciesScaleShift, E3PerEdgeSpeciesScaleShift
import logging

log = logging.getLogger(__name__)

""" if this class is called, it suggest user choose a embedding method. If not, it should directly use _sktb.py
"""
@@ -382,6 +384,10 @@ def from_reference(
if len(nnsk) == 0:
model_options["nnsk"] = ckpt["config"]["model_options"]["nnsk"]

if model_options["nnsk"].get("push") is not None:
model_options["nnsk"]["push"] = None
log.warning("The push option is not supported in the mixed model. The push option is only supported in the nnsk model.")

if len(embedding) == 0 or len(prediction) == 0:
assert ckpt["config"]["model_options"].get("embedding") is not None and ckpt["config"]["model_options"].get("prediction") is not None, \
"The reference model checkpoint should come from a mixed model if dptb info is not provided."
5 changes: 3 additions & 2 deletions dptb/nn/nnsk.py
@@ -62,7 +62,8 @@ def __init__(
"onsite": onsite,
"hopping": hopping,
"freeze": freeze,
"push": push,
"push": push,
"std": std
}
}

@@ -348,7 +349,7 @@ def from_reference(
if v is None:
common_options[k] = f["config"]["common_options"][k]
for k,v in nnsk.items():
if v is None:
if v is None and k != "push":
nnsk[k] = f["config"]["model_options"]["nnsk"][k]

model = cls(**common_options, **nnsk)
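The loop above fills unset (None) nnsk options from the reference checkpoint but deliberately skips "push", so a push schedule stored in the checkpoint is never resurrected when the user did not ask for one. A minimal sketch of that merge rule (function and variable names are illustrative):

    def merge_nnsk_options(user_options: dict, ckpt_options: dict) -> dict:
        # fill unset (None) options from the checkpoint, but never inherit "push"
        merged = dict(user_options)
        for key, value in merged.items():
            if value is None and key != "push":
                merged[key] = ckpt_options[key]
        return merged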
68 changes: 48 additions & 20 deletions dptb/plugins/saver.py
@@ -21,13 +21,36 @@ def register(self, trainer, checkpoint_path):
self.checkpoint_path = checkpoint_path
self.trainer = trainer

if self.trainer.model.name == "nnsk":
# get the push option
push_option = self.trainer.model.model_options["nnsk"].get("push", False)
if push_option:
# sum the absolute values of all threshold ("thr") entries
thrs = sum(abs(val) for key, val in push_option.items() if "thr" in key)
# if the sum of thresholds is non-zero, push is True
push = abs(thrs) != 0.0
else:
push = False
else:
push = False
self.push = push


def iteration(self, **kwargs):
# suffix = "_b"+"%.3f"%self.trainer.common_options["bond_cutoff"]+"_c"+"%.3f"%self.trainer.onsite_options["skfunction"]["sk_cutoff"]+"_w"+\
# "%.3f"%self.trainer.model_options["skfunction"]["sk_decay_w"]
suffix = ".iter{}".format(self.trainer.iter)
if self.push:
suffix = ".iter_rs" + "%.3f"%self.trainer.model.hopping_options["rs"]+"_w"+"%.3f"%self.trainer.model.hopping_options["w"]
# By default, the maximum number of saved checkpoints is 100 for pushing rs and w.
max_ckpt = 100
else:
suffix = ".iter{}".format(self.trainer.iter)
max_ckpt = self.trainer.train_options["max_ckpt"]

name = self.trainer.model.name+suffix
self.latest_quene.append(name)
if len(self.latest_quene) >= 5:

if len(self.latest_quene) > max_ckpt:
delete_name = self.latest_quene.pop(0)
delete_path = os.path.join(self.checkpoint_path, delete_name+".pth")
os.remove(delete_path)
@@ -40,11 +63,17 @@ def iteration(self, **kwargs):
train_options=self.trainer.train_options,
)

# if self.trainer.name == "dptb" \
# and self.trainer.run_opt["use_correction"] \
# and not self.trainer.run_opt["freeze"]:

# self._save(name="latest_"+self.trainer.name+'_nnsk'+suffix,model=self.trainer.sknet, model_config=self.trainer.sknet_config)
if not self.push:
# create a symlink pointing to the latest model
latest_symlink = os.path.join(self.checkpoint_path, self.trainer.model.name + ".latest.pth")
if os.path.lexists(latest_symlink):
os.unlink(latest_symlink)
latest_ckpt = os.path.join(self.checkpoint_path, name+".pth")
latest_ckpt_abs_path = os.path.abspath(latest_ckpt)
# make sure the source file exists
if not os.path.exists(latest_ckpt_abs_path):
raise FileNotFoundError(f"Source file {latest_ckpt_abs_path} does not exist.")
os.symlink(latest_ckpt_abs_path, latest_symlink)

def epoch(self, **kwargs):

@@ -54,14 +83,15 @@ def epoch(self, **kwargs):
else:
updated_loss = self.trainer.stats.get("train_loss").get("epoch_mean",1e6)

max_ckpt = self.trainer.train_options["max_ckpt"]

if updated_loss < self.best_loss:
# suffix = "_b"+"%.3f"%self.trainer.common_options["bond_cutoff"]+"_c"+"%.3f"%self.trainer.model_options["skfunction"]["sk_cutoff"]+"_w"+\
# "%.3f"%self.trainer.model_options["skfunction"]["sk_decay_w"]
suffix = ".ep{}".format(self.trainer.ep)
name = self.trainer.model.name+suffix
self.best_quene.append(name)
if len(self.best_quene) >= 5:
if len(self.best_quene) > max_ckpt:
delete_name = self.best_quene.pop(0)
delete_path = os.path.join(self.checkpoint_path, delete_name+".pth")
os.remove(delete_path)
@@ -76,18 +106,16 @@ def iteration(self, **kwargs):

self.best_loss = updated_loss

# if self.trainer.name == "dptb" \
# and self.trainer.run_opt["use_correction"] \
# and not self.trainer.run_opt["freeze"]:

# self._save(
# name="best_"+self.trainer.name+'_nnsk'+suffix,
# model=self.trainer.sknet,
# model_config=self.trainer.sknet_config
# common_options=self.trainer.common_options
# )

# log.info(msg="checkpoint saved as {}".format("best_epoch"))
# create a symlink pointing to the best model
best_symlink = os.path.join(self.checkpoint_path, self.trainer.model.name + ".best.pth")
if os.path.lexists(best_symlink):
os.unlink(best_symlink)
best_ckpt = os.path.join(self.checkpoint_path, name+".pth")
best_ckpt_abs_path = os.path.abspath(best_ckpt)
# make sure the source file exists
if not os.path.exists(best_ckpt_abs_path):
raise FileNotFoundError(f"Source file {best_ckpt_abs_path} does not exist.")
os.symlink(best_ckpt_abs_path, best_symlink)

def _save(self, name, model, model_options, common_options, train_options):
obj = {}
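Taken together, the saver changes do three things: decide whether a push schedule is active (any non-zero "thr" threshold), tag push checkpoints with the current rs and w instead of the iteration number, and maintain <model>.latest.pth / <model>.best.pth symlinks pointing at the most recent files. A condensed, standalone sketch of that logic (helper names are illustrative, not the plugin's API):

    import os

    def is_push_active(push_option) -> bool:
        # push is considered active when any "thr" threshold in the push options is non-zero
        if not push_option:
            return False
        return sum(abs(v) for k, v in push_option.items() if "thr" in k) != 0.0

    def checkpoint_suffix(push: bool, iteration: int, hopping_options: dict) -> str:
        # push checkpoints are tagged with rs and w, regular ones with the iteration number
        if push:
            return ".iter_rs%.3f_w%.3f" % (hopping_options["rs"], hopping_options["w"])
        return ".iter{}".format(iteration)

    def update_symlink(checkpoint_path: str, model_name: str, ckpt_name: str, tag: str) -> None:
        # e.g. nnsk.latest.pth -> /abs/path/nnsk.iter100.pth
        symlink = os.path.join(checkpoint_path, f"{model_name}.{tag}.pth")
        if os.path.lexists(symlink):
            os.unlink(symlink)
        target = os.path.abspath(os.path.join(checkpoint_path, ckpt_name + ".pth"))
        if not os.path.exists(target):
            raise FileNotFoundError(f"Source file {target} does not exist.")
        os.symlink(target, symlink)

Note that the rolling checkpoint queue is capped at 100 entries in the push branch, while the regular branch uses the new max_ckpt train option.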
12 changes: 6 additions & 6 deletions dptb/tests/test_sktb.py
@@ -23,7 +23,7 @@ def test_nnsk_valence(root_directory):
def test_nnsk_strain_polar(root_directory):
INPUT_file = root_directory+"/dptb/tests/data/test_sktb/input/input_strain_polar.json"
output = root_directory+"/dptb/tests/data/test_sktb/output"
init_model = root_directory+"/dptb/tests/data/test_sktb/output/test_valence/checkpoint/nnsk.iter2.pth"
init_model = root_directory+"/dptb/tests/data/test_sktb/output/test_valence/checkpoint/nnsk.latest.pth"

check_config_train(INPUT=INPUT_file, init_model=None, restart=None, train_soc=False)
train(INPUT=INPUT_file, init_model=init_model, restart=None, train_soc=False,\
@@ -36,7 +36,7 @@ def test_nnsk_push(root_directory):
INPUT_file_rs = root_directory + "/dptb/tests/data/test_sktb/input/input_push_rs.json"
INPUT_file_w = root_directory + "/dptb/tests/data/test_sktb/input/input_push_w.json"
output = root_directory + "/dptb/tests/data/test_sktb/output"
init_model = root_directory + "/dptb/tests/data/test_sktb/output/test_strain_polar/checkpoint/nnsk.ep2.pth"
init_model = root_directory + "/dptb/tests/data/test_sktb/output/test_strain_polar/checkpoint/nnsk.best.pth"

check_config_train(INPUT=INPUT_file_rs, init_model=None, restart=None, train_soc=False)
train(INPUT=INPUT_file_rs, init_model=init_model, restart=None, train_soc=False,\
@@ -46,8 +46,8 @@ def test_nnsk_push(root_directory):
train(INPUT=INPUT_file_w, init_model=init_model, restart=None, train_soc=False,\
output=output+"/test_push_w", log_level=5, log_path=output+"/test_push_w.log")

model_rs = torch.load(f"{root_directory}/dptb/tests/data/test_sktb/output/test_push_rs/checkpoint/nnsk.iter10.pth")
model_w = torch.load(f"{root_directory}/dptb/tests/data/test_sktb/output/test_push_w/checkpoint/nnsk.iter10.pth")
model_rs = torch.load(f"{root_directory}/dptb/tests/data/test_sktb/output/test_push_rs/checkpoint/nnsk.iter_rs2.650_w0.300.pth")
model_w = torch.load(f"{root_directory}/dptb/tests/data/test_sktb/output/test_push_w/checkpoint/nnsk.iter_rs5.000_w0.350.pth")
# test push limits
# 10 epoch, 0.01 step, 1 period -> 0.05 added.
assert np.isclose(model_rs["config"]["model_options"]["nnsk"]["hopping"]["rs"], 2.65)
@@ -58,7 +58,7 @@ def test_md(root_directory):
def test_md(root_directory):
INPUT_file =root_directory + "/dptb/tests/data/test_sktb/input/input_md.json"
output = root_directory + "/dptb/tests/data/test_sktb/output"
init_model = root_directory + "/dptb/tests/data/test_sktb/output/test_push_w/checkpoint/nnsk.iter10.pth"
init_model = root_directory + "/dptb/tests/data/test_sktb/output/test_push_w/checkpoint/nnsk.iter_rs5.000_w0.350.pth"

check_config_train(INPUT=INPUT_file, init_model=None, restart=None, train_soc=False)
train(INPUT=INPUT_file, init_model=init_model, restart=None, train_soc=False,\
@@ -69,7 +69,7 @@ def test_dptb(root_directory):
def test_dptb(root_directory):
INPUT_file =root_directory + "/dptb/tests/data/test_sktb/input/input_dptb.json"
output = root_directory + "/dptb/tests/data/test_sktb/output"
init_model = root_directory + "/dptb/tests/data/test_sktb/output/test_md/checkpoint/nnsk.ep2.pth"
init_model = root_directory + "/dptb/tests/data/test_sktb/output/test_md/checkpoint/nnsk.latest.pth"

check_config_train(INPUT=INPUT_file, init_model=None, restart=None, train_soc=False)
train(INPUT=INPUT_file, init_model=init_model, restart=None, train_soc=False,\
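The tests now initialize from the stable symlink names (nnsk.latest.pth, nnsk.best.pth) and from the rs/w-tagged push checkpoints instead of hard-coded iteration files, so they no longer break when save frequencies change. Loading through a symlink works exactly like loading the file it points to; a small example using the fixture path from the tests above (relative to the repository root):

    import torch

    # nnsk.latest.pth is a symlink maintained by the saver plugin
    ckpt = torch.load("dptb/tests/data/test_sktb/output/test_valence/checkpoint/nnsk.latest.pth")
    print(ckpt["config"]["model_options"]["nnsk"]["hopping"])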
4 changes: 3 additions & 1 deletion dptb/utils/argcheck.py
@@ -53,7 +53,8 @@ def train_options():
- `LBFGS`: [On the limited memory BFGS method for large scale optimization.](http://users.iems.northwestern.edu/~nocedal/PDFfiles/limited-memory.pdf) \n\n\
"
doc_lr_scheduler = "The learning rate scheduler tools settings, the lr scheduler is used to scales down the learning rate during the training process. Proper setting can make the training more stable and efficient. The supported lr schedular includes: `Exponential Decaying (exp)`, `Linear multiplication (linear)`"
doc_batch_size = ""
doc_batch_size = "The batch size used in training, Default: 1"
doc_max_ckpt = "The maximum number of saved checkpoints, Default: 4"

args = [
Argument("num_epoch", int, optional=False, doc=doc_num_epoch),
@@ -65,6 +66,7 @@
Argument("save_freq", int, optional=True, default=10, doc=doc_save_freq),
Argument("validation_freq", int, optional=True, default=10, doc=doc_validation_freq),
Argument("display_freq", int, optional=True, default=1, doc=doc_display_freq),
Argument("max_ckpt", int, optional=True, default=4, doc=doc_max_ckpt),
loss_options()
]

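With the new argument, max_ckpt sits alongside the other train_options fields and bounds how many rolling checkpoints the saver keeps before deleting the oldest. A minimal excerpt mirroring the defaults registered in this diff (other blocks such as the optimizer, lr_scheduler, and loss options are omitted):

    train_options = {
        "num_epoch": 100,        # required
        "save_freq": 10,         # default 10
        "validation_freq": 10,   # default 10
        "display_freq": 1,       # default 1
        "max_ckpt": 4,           # new: keep at most 4 rolling checkpoints (default 4)
    }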
3 changes: 2 additions & 1 deletion dptb/utils/config_sk.py
@@ -39,7 +39,8 @@
"validation_freq": 10,
"display_freq": 100,
"ref_batch_size": 1,
"val_batch_size": 1
"val_batch_size": 1,
"max_ckpt":4
},
"model_options": {
"nnsk": {
3 changes: 2 additions & 1 deletion dptb/utils/config_skenv.py
@@ -39,7 +39,8 @@
"validation_freq": 10,
"display_freq": 100,
"ref_batch_size": 1,
"val_batch_size": 1
"val_batch_size": 1,
"max_ckpt":4
},
"model_options": {
"embedding": {
