-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrun-text.py
53 lines (42 loc) · 1.9 KB
/
run-text.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from pytorch_lightning.cli import LightningCLI
from pytorch_lightning.callbacks import ModelCheckpoint, ModelSummary
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from src.models.parent import ParentModule
from src.data.data_module import S2TSumDataMod
class T2TSumLightningCLI(LightningCLI):
    """LightningCLI for the text-to-text summarization experiments.

    Extends the stock ``LightningCLI`` parser with argument links between the
    model, data module and trainer, and registers a fixed set of callbacks
    with project defaults. Every callback setting can still be overridden from
    the command line or a config file under the keys ``early_stopping``,
    ``val_loss_checkpoint``, ``rouge2_checkpoint`` and ``model_summary``.
    """

    def add_arguments_to_parser(self, parser):
        """Register argument links and callback arguments with their defaults.

        Args:
            parser: The argument parser supplied by ``LightningCLI``
                (a jsonargparse-based ``LightningArgumentParser``).
        """
        # The model is parsed in subclass mode (``subclass_mode_model=True``
        # in ``cli_main``), so its constructor arguments live under
        # ``model.init_args.*``. Both links must use that prefix for the
        # model's values to be found and copied into the data module and the
        # trainer at parse time.
        # NOTE(review): assumes the selected ParentModule subclass accepts
        # ``batch_size`` and ``accumulate_grad_batches`` init args — confirm.
        parser.link_arguments(
            'model.init_args.batch_size',
            'data.batch_size',
            apply_on='parse',
        )
        parser.link_arguments(
            'model.init_args.accumulate_grad_batches',
            'trainer.accumulate_grad_batches',
            apply_on='parse',
        )

        # Stop training once the validation loss has not decreased for 20
        # consecutive validation checks.
        parser.add_lightning_class_args(EarlyStopping, 'early_stopping')
        parser.set_defaults({
            'early_stopping.monitor': 'val/loss',
            'early_stopping.mode': 'min',
            'early_stopping.patience': 20,
        })

        # Keep the single best checkpoint by validation loss, plus a rolling
        # "last" checkpoint for resuming.
        parser.add_lightning_class_args(ModelCheckpoint, 'val_loss_checkpoint')
        parser.set_defaults({
            'val_loss_checkpoint.monitor': 'val/loss',
            'val_loss_checkpoint.mode': 'min',
            'val_loss_checkpoint.save_top_k': 1,
            'val_loss_checkpoint.save_last': True,
            'val_loss_checkpoint.filename': '{epoch}-{step}',
        })

        # Independently keep the single best checkpoint by ROUGE-2 (higher is
        # better). ``save_last`` is off here because the val-loss checkpoint
        # above already writes the "last" checkpoint.
        parser.add_lightning_class_args(ModelCheckpoint, 'rouge2_checkpoint')
        parser.set_defaults({
            'rouge2_checkpoint.monitor': 'metric/rouge2',
            'rouge2_checkpoint.mode': 'max',
            'rouge2_checkpoint.save_top_k': 1,
            'rouge2_checkpoint.save_last': False,
            'rouge2_checkpoint.filename': 'rouge2-{epoch}-{step}',
        })

        # max_depth=-1 asks Lightning's ModelSummary to list modules at every
        # nesting depth rather than only the top level.
        parser.add_lightning_class_args(ModelSummary, 'model_summary')
        parser.set_defaults({'model_summary.max_depth': -1})
def cli_main():
    """Build the summarization CLI from command-line arguments and run it.

    With Lightning's default ``run=True``, constructing the CLI parses
    ``sys.argv`` (and any config files) and executes the requested
    subcommand.

    Returns:
        The constructed ``T2TSumLightningCLI`` instance, so callers can
        inspect the parsed config, model, datamodule or trainer afterwards.
    """
    # subclass_mode_model=True: any ParentModule subclass can be selected via
    # ``model.class_path``; its constructor args then live under
    # ``model.init_args``.
    # save_config_overwrite=True: overwrite an existing saved config in the
    # log directory instead of refusing to run.
    # NOTE(review): ``save_config_overwrite`` is deprecated in Lightning 1.8
    # and removed in 2.0 in favor of ``save_config_kwargs={"overwrite": True}``;
    # update this if the pinned Lightning version moves past 1.x.
    return T2TSumLightningCLI(
        ParentModule,
        S2TSumDataMod,
        subclass_mode_model=True,
        save_config_overwrite=True,
    )
if __name__ == "__main__":
    # Script entry point: only launch the CLI when executed directly, not on
    # import.
    cli_main()