diff --git a/.gitignore b/.gitignore
index 1375abfe..b5f10876 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,6 +6,7 @@ dptb/tests/**/*.pth
 dptb/tests/**/*.npy
 dptb/tests/**/*.traj
 dptb/tests/**/out*/*
+examples/_*
 *.dat
 *.vasp
 *log*
diff --git a/dptb/nn/build.py b/dptb/nn/build.py
index 7323340d..ea90357d 100644
--- a/dptb/nn/build.py
+++ b/dptb/nn/build.py
@@ -133,6 +133,9 @@ def build_model(run_options, model_options: dict={}, common_options: dict={}, st
         model = NNSK.from_reference(checkpoint, **model_options["nnsk"], **common_options)
     if init_mixed:
         # mix model can be initilized with a mixed reference model or a nnsk model.
-        model = MIX.from_reference(checkpoint, **model_options, **common_options)
+        model = MIX.from_reference(checkpoint, **model_options, **common_options)
+        if model.model_options != model_options:
+            # log.error("The model options are not consistent with the checkpoint, using the one in the checkpoint.")
+            raise ValueError("The model options are not consistent with those in the checkpoint.")

     return model
diff --git a/dptb/nn/deeptb.py b/dptb/nn/deeptb.py
index 3a82dd2c..a3290bbf 100644
--- a/dptb/nn/deeptb.py
+++ b/dptb/nn/deeptb.py
@@ -10,7 +10,9 @@ from dptb.nn.nnsk import NNSK
 from e3nn.o3 import Linear
 from dptb.nn.rescale import E3PerSpeciesScaleShift, E3PerEdgeSpeciesScaleShift

+import logging
+log = logging.getLogger(__name__)

 """
 if this class is called, it suggest user choose a embedding method. If not, it should directly use _sktb.py
 """
@@ -382,6 +384,10 @@ def from_reference(
         if len(nnsk) == 0:
             model_options["nnsk"] = ckpt["config"]["model_options"]["nnsk"]

+        if model_options["nnsk"].get("push") is not None:
+            model_options["nnsk"]["push"] = None
+            log.warning("The push option is not supported in the mixed model. The push option is only supported in the nnsk model.")
+
         if len(embedding) == 0 or len(prediction) == 0:
             assert ckpt["config"]["model_options"].get("embedding") is not None and ckpt["config"]["model_options"].get("prediction") is not None, \
                 "The reference model checkpoint should come from a mixed model if dptb info is not provided."
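For context: the guard added to `from_reference` above clears any push schedule inherited from an nnsk reference checkpoint, since pushing rs/w is only supported when training a pure nnsk model. A minimal standalone sketch of that behavior (the helper name `strip_push_option` and the example option keys are illustrative assumptions, not part of the patch):

import logging

logging.basicConfig(level=logging.WARNING)
log = logging.getLogger(__name__)

def strip_push_option(model_options: dict) -> dict:
    # mirrors the patch: drop a push schedule carried over from an nnsk
    # reference checkpoint, warning the user instead of silently keeping it
    if model_options.get("nnsk", {}).get("push") is not None:
        model_options["nnsk"]["push"] = None
        log.warning("The push option is not supported in the mixed model. "
                    "The push option is only supported in the nnsk model.")
    return model_options

# hypothetical option keys, for illustration only
opts = {"nnsk": {"push": {"rs_thr": 0.01, "w_thr": 0.0}}}
assert strip_push_option(opts)["nnsk"]["push"] is None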
diff --git a/dptb/nn/nnsk.py b/dptb/nn/nnsk.py
index fb07f7d6..da969039 100644
--- a/dptb/nn/nnsk.py
+++ b/dptb/nn/nnsk.py
@@ -62,7 +62,8 @@ def __init__(
             "onsite": onsite,
             "hopping": hopping,
             "freeze": freeze,
-            "push": push,
+            "push": push,
+            "std": std
             }
         }
@@ -348,7 +349,7 @@ def from_reference(
             if v is None:
                 common_options[k] = f["config"]["common_options"][k]
         for k,v in nnsk.items():
-            if v is None:
+            if v is None and k != "push":
                 nnsk[k] = f["config"]["model_options"]["nnsk"][k]

         model = cls(**common_options, **nnsk)
diff --git a/dptb/plugins/saver.py b/dptb/plugins/saver.py
index 36d5a64c..4d2c147a 100644
--- a/dptb/plugins/saver.py
+++ b/dptb/plugins/saver.py
@@ -21,13 +21,36 @@ def register(self, trainer, checkpoint_path):
         self.checkpoint_path = checkpoint_path
         self.trainer = trainer

+        if self.trainer.model.name == "nnsk":
+            # fetch the push option
+            push_option = self.trainer.model.model_options["nnsk"].get("push", False)
+            if push_option:
+                # sum the absolute values of all thresholds
+                thrs = sum(abs(val) for key, val in push_option.items() if "thr" in key)
+                # push is enabled only if the threshold sum is non-zero
+                push = thrs != 0.0
+            else:
+                push = False
+        else:
+            push = False
+        self.push = push
+
+
     def iteration(self, **kwargs):
         # suffix = "_b"+"%.3f"%self.trainer.common_options["bond_cutoff"]+"_c"+"%.3f"%self.trainer.onsite_options["skfunction"]["sk_cutoff"]+"_w"+\
         #     "%.3f"%self.trainer.model_options["skfunction"]["sk_decay_w"]
-        suffix = ".iter{}".format(self.trainer.iter)
+        if self.push:
+            suffix = ".iter_rs" + "%.3f"%self.trainer.model.hopping_options["rs"]+"_w"+"%.3f"%self.trainer.model.hopping_options["w"]
+            # By default, the maximum number of saved checkpoints is 100 when pushing rs and w.
+            max_ckpt = 100
+        else:
+            suffix = ".iter{}".format(self.trainer.iter)
+            max_ckpt = self.trainer.train_options["max_ckpt"]
+
         name = self.trainer.model.name+suffix
         self.latest_quene.append(name)
-        if len(self.latest_quene) >= 5:
+
+        if len(self.latest_quene) > max_ckpt:
             delete_name = self.latest_quene.pop(0)
             delete_path = os.path.join(self.checkpoint_path, delete_name+".pth")
             os.remove(delete_path)
@@ -40,11 +63,17 @@ def iteration(self, **kwargs):
             train_options=self.trainer.train_options,
         )

-        # if self.trainer.name == "dptb" \
-        #     and self.trainer.run_opt["use_correction"] \
-        #     and not self.trainer.run_opt["freeze"]:
-
-        #     self._save(name="latest_"+self.trainer.name+'_nnsk'+suffix,model=self.trainer.sknet, model_config=self.trainer.sknet_config)
+        if not self.push:
+            # create a symlink pointing to the latest model
+            latest_symlink = os.path.join(self.checkpoint_path, self.trainer.model.name + ".latest.pth")
+            if os.path.lexists(latest_symlink):
+                os.unlink(latest_symlink)
+            latest_ckpt = os.path.join(self.checkpoint_path, name+".pth")
+            latest_ckpt_abs_path = os.path.abspath(latest_ckpt)
+            # make sure the source checkpoint exists
+            if not os.path.exists(latest_ckpt_abs_path):
+                raise FileNotFoundError(f"Source file {latest_ckpt_abs_path} does not exist.")
+            os.symlink(latest_ckpt_abs_path, latest_symlink)

     def epoch(self, **kwargs):
@@ -54,6 +83,7 @@ def epoch(self, **kwargs):
         else:
             updated_loss = self.trainer.stats.get("train_loss").get("epoch_mean",1e6)

+        max_ckpt = self.trainer.train_options["max_ckpt"]
         if updated_loss < self.best_loss:
             # suffix = "_b"+"%.3f"%self.trainer.common_options["bond_cutoff"]+"_c"+"%.3f"%self.trainer.model_options["skfunction"]["sk_cutoff"]+"_w"+\
             #     "%.3f"%self.trainer.model_options["skfunction"]["sk_decay_w"]
             suffix = ".ep{}".format(self.trainer.ep)
             name = self.trainer.model.name+suffix
             self.best_quene.append(name)
-            if len(self.best_quene) >= 5:
+            if len(self.best_quene) > max_ckpt:
                delete_name = self.best_quene.pop(0)
                delete_path = os.path.join(self.checkpoint_path, delete_name+".pth")
                os.remove(delete_path)
@@ -76,18 +106,16 @@ def epoch(self, **kwargs):

             self.best_loss = updated_loss

-        # if self.trainer.name == "dptb" \
-        #     and self.trainer.run_opt["use_correction"] \
-        #     and not self.trainer.run_opt["freeze"]:
-
-        #     self._save(
-        #         name="best_"+self.trainer.name+'_nnsk'+suffix,
-        #         model=self.trainer.sknet,
-        #         model_config=self.trainer.sknet_config
-        #         common_options=self.trainer.common_options
-        #     )
-
-        #     log.info(msg="checkpoint saved as {}".format("best_epoch"))
+            # create a symlink pointing to the best model
+            best_symlink = os.path.join(self.checkpoint_path, self.trainer.model.name + ".best.pth")
+            if os.path.lexists(best_symlink):
+                os.unlink(best_symlink)
+            best_ckpt = os.path.join(self.checkpoint_path, name+".pth")
+            best_ckpt_abs_path = os.path.abspath(best_ckpt)
+            # make sure the source checkpoint exists
+            if not os.path.exists(best_ckpt_abs_path):
+                raise FileNotFoundError(f"Source file {best_ckpt_abs_path} does not exist.")
+            os.symlink(best_ckpt_abs_path, best_symlink)

     def _save(self, name, model, model_options, common_options, train_options):
         obj = {}
diff --git a/dptb/tests/test_sktb.py b/dptb/tests/test_sktb.py
index 6bf9ccb1..e5fc58e3 100644
--- a/dptb/tests/test_sktb.py
+++ b/dptb/tests/test_sktb.py
@@ -23,7 +23,7 @@ def test_nnsk_valence(root_directory):
 def test_nnsk_strain_polar(root_directory):
     INPUT_file = root_directory+"/dptb/tests/data/test_sktb/input/input_strain_polar.json"
     output = root_directory+"/dptb/tests/data/test_sktb/output"
-    init_model = root_directory+"/dptb/tests/data/test_sktb/output/test_valence/checkpoint/nnsk.iter2.pth"
+    init_model = root_directory+"/dptb/tests/data/test_sktb/output/test_valence/checkpoint/nnsk.latest.pth"

     check_config_train(INPUT=INPUT_file, init_model=None, restart=None, train_soc=False)
     train(INPUT=INPUT_file, init_model=init_model, restart=None, train_soc=False,\
@@ -36,7 +36,7 @@ def test_nnsk_push(root_directory):
     INPUT_file_rs = root_directory + "/dptb/tests/data/test_sktb/input/input_push_rs.json"
     INPUT_file_w = root_directory + "/dptb/tests/data/test_sktb/input/input_push_w.json"
     output = root_directory + "/dptb/tests/data/test_sktb/output"
-    init_model = root_directory + "/dptb/tests/data/test_sktb/output/test_strain_polar/checkpoint/nnsk.ep2.pth"
+    init_model = root_directory + "/dptb/tests/data/test_sktb/output/test_strain_polar/checkpoint/nnsk.best.pth"

     check_config_train(INPUT=INPUT_file_rs, init_model=None, restart=None, train_soc=False)
     train(INPUT=INPUT_file_rs, init_model=init_model, restart=None, train_soc=False,\
@@ -46,8 +46,8 @@ def test_nnsk_push(root_directory):
     train(INPUT=INPUT_file_w, init_model=init_model, restart=None, train_soc=False,\
         output=output+"/test_push_w", log_level=5, log_path=output+"/test_push_w.log")

-    model_rs = torch.load(f"{root_directory}/dptb/tests/data/test_sktb/output/test_push_rs/checkpoint/nnsk.iter10.pth")
-    model_w = torch.load(f"{root_directory}/dptb/tests/data/test_sktb/output/test_push_w/checkpoint/nnsk.iter10.pth")
+    model_rs = torch.load(f"{root_directory}/dptb/tests/data/test_sktb/output/test_push_rs/checkpoint/nnsk.iter_rs2.650_w0.300.pth")
+    model_w = torch.load(f"{root_directory}/dptb/tests/data/test_sktb/output/test_push_w/checkpoint/nnsk.iter_rs5.000_w0.350.pth")

     # test push limits
     # 10 epoch, 0.01 step, 1 period -> 0.05 added.
     assert np.isclose(model_rs["config"]["model_options"]["nnsk"]["hopping"]["rs"], 2.65)
@@ -58,7 +58,7 @@ def test_nnsk_push(root_directory):

 def test_md(root_directory):
     INPUT_file =root_directory + "/dptb/tests/data/test_sktb/input/input_md.json"
     output = root_directory + "/dptb/tests/data/test_sktb/output"
-    init_model = root_directory + "/dptb/tests/data/test_sktb/output/test_push_w/checkpoint/nnsk.iter10.pth"
+    init_model = root_directory + "/dptb/tests/data/test_sktb/output/test_push_w/checkpoint/nnsk.iter_rs5.000_w0.350.pth"

     check_config_train(INPUT=INPUT_file, init_model=None, restart=None, train_soc=False)
     train(INPUT=INPUT_file, init_model=init_model, restart=None, train_soc=False,\
@@ -69,7 +69,7 @@ def test_dptb(root_directory):
     INPUT_file =root_directory + "/dptb/tests/data/test_sktb/input/input_dptb.json"
     output = root_directory + "/dptb/tests/data/test_sktb/output"
-    init_model = root_directory + "/dptb/tests/data/test_sktb/output/test_md/checkpoint/nnsk.ep2.pth"
+    init_model = root_directory + "/dptb/tests/data/test_sktb/output/test_md/checkpoint/nnsk.latest.pth"

     check_config_train(INPUT=INPUT_file, init_model=None, restart=None, train_soc=False)
     train(INPUT=INPUT_file, init_model=init_model, restart=None, train_soc=False,\
diff --git a/dptb/utils/argcheck.py b/dptb/utils/argcheck.py
index 58dd1263..ab0bb228 100644
--- a/dptb/utils/argcheck.py
+++ b/dptb/utils/argcheck.py
@@ -53,7 +53,8 @@ def train_options():
                     - `LBFGS`: [On the limited memory BFGS method for large scale optimization.](http://users.iems.northwestern.edu/~nocedal/PDFfiles/limited-memory.pdf) \n\n\
                     "
     doc_lr_scheduler = "The learning rate scheduler tools settings, the lr scheduler is used to scales down the learning rate during the training process. Proper setting can make the training more stable and efficient. The supported lr schedular includes: `Exponential Decaying (exp)`, `Linear multiplication (linear)`"
-    doc_batch_size = ""
+    doc_batch_size = "The batch size used in training. Default: 1"
+    doc_max_ckpt = "The maximum number of saved checkpoints. Default: 4"

     args = [
         Argument("num_epoch", int, optional=False, doc=doc_num_epoch),
@@ -65,6 +66,7 @@ def train_options():
         Argument("save_freq", int, optional=True, default=10, doc=doc_save_freq),
         Argument("validation_freq", int, optional=True, default=10, doc=doc_validation_freq),
         Argument("display_freq", int, optional=True, default=1, doc=doc_display_freq),
+        Argument("max_ckpt", int, optional=True, default=4, doc=doc_max_ckpt),
         loss_options()
     ]
diff --git a/dptb/utils/config_sk.py b/dptb/utils/config_sk.py
index 5d060ea0..2ef1878b 100644
--- a/dptb/utils/config_sk.py
+++ b/dptb/utils/config_sk.py
@@ -39,7 +39,8 @@
         "validation_freq": 10,
         "display_freq": 100,
         "ref_batch_size": 1,
-        "val_batch_size": 1
+        "val_batch_size": 1,
+        "max_ckpt": 4
     },
     "model_options": {
         "nnsk": {
diff --git a/dptb/utils/config_skenv.py b/dptb/utils/config_skenv.py
index 20b3f508..03a84f7c 100644
--- a/dptb/utils/config_skenv.py
+++ b/dptb/utils/config_skenv.py
@@ -39,7 +39,8 @@
         "validation_freq": 10,
         "display_freq": 100,
         "ref_batch_size": 1,
-        "val_batch_size": 1
+        "val_batch_size": 1,
+        "max_ckpt": 4
     },
     "model_options": {
         "embedding": {
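Taken together, the saver changes give push runs rs/w-stamped checkpoint names while normal runs keep iteration-stamped names, rotate at most max_ckpt checkpoints, and refresh `<model>.latest.pth` / `<model>.best.pth` symlinks that the updated tests load. A minimal standalone sketch of that logic (the function names and the `checkpoint_dir` argument are illustrative, not part of the patch):

import os

def checkpoint_suffix(push: bool, it: int, rs: float = 0.0, w: float = 0.0) -> str:
    # push runs encode the current rs/w cutoffs in the file name, so the
    # checkpoints along the pushing trajectory (e.g. nnsk.iter_rs2.650_w0.300.pth)
    # are kept apart instead of overwriting one another
    if push:
        return ".iter_rs" + "%.3f" % rs + "_w" + "%.3f" % w
    return ".iter{}".format(it)

def refresh_symlink(checkpoint_dir: str, model_name: str, target_name: str, tag: str = "latest") -> None:
    # <model>.latest.pth (or <model>.best.pth) always resolves to the newest
    # (or best) checkpoint, giving the tests a stable path to init from
    link = os.path.join(checkpoint_dir, model_name + "." + tag + ".pth")
    target = os.path.abspath(os.path.join(checkpoint_dir, target_name + ".pth"))
    if not os.path.exists(target):
        raise FileNotFoundError(f"Source file {target} does not exist.")
    if os.path.lexists(link):
        os.unlink(link)
    os.symlink(target, link)

# e.g. checkpoint_suffix(True, 0, rs=2.65, w=0.3) -> ".iter_rs2.650_w0.300",
# matching the file names asserted in test_nnsk_push above.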