Skip to content

Commit

Permalink
adding save configure
Browse files Browse the repository at this point in the history
  • Loading branch information
[email protected] committed Aug 22, 2019
1 parent 402ecd6 commit f1d0237
Showing 1 changed file with 18 additions and 0 deletions.
18 changes: 18 additions & 0 deletions code/main-multislot.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,7 @@ def main():
if args.fp16:
model.half()
model.to(device)
save_configure(args, num_labels, processor.ontology)

## Get slot-value embeddings
label_token_ids, label_len = [], []
Expand Down Expand Up @@ -1014,6 +1015,23 @@ def _eval_acc(_pred_slot, _labels):

return accuracies

def save_configure(args, num_labels, ontology):
with open(os.path.join(args.output_dir, "config.json"),'w') as outfile:
data = { "hidden_dim": args.hidden_dim,
"num_rnn_layers": args.num_rnn_layers,
"zero_init_rnn": args.zero_init_rnn,
"max_seq_length": args.max_seq_length,
"max_label_length": args.max_label_length,
"num_labels": num_labels,
"attn_head": args.attn_head,
"distance_metric": args.distance_metric,
"fix_utterance_encoder": args.fix_utterance_encoder,
"task_name": args.task_name,
"bert_dir": args.bert_dir,
"bert_model": args.bert_model,
"do_lower_case": args.do_lower_case,
"ontology": ontology}
json.dump(data, outfile, indent=4)

if __name__ == "__main__":
main()

0 comments on commit f1d0237

Please sign in to comment.