Skip to content

Commit

Permalink
feat: print num params and avg inference time;
Browse files Browse the repository at this point in the history
  • Loading branch information
WenjieDu committed Jun 1, 2024
1 parent 8ba3ab9 commit 0e36945
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions benchmark_code/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import argparse
import os
import time

import numpy as np
import torch
Expand Down Expand Up @@ -107,6 +108,7 @@
mae_collector = []
mse_collector = []
mre_collector = []
time_collector = []

result_saving_path = os.path.join(args.saving_path, f"{args.model}_{args.dataset}")
for n_round in range(args.n_rounds):
Expand All @@ -129,13 +131,15 @@
model.fit(train_set=train_set, val_set=val_set)

test_set = {"X": test_X}
start_time = time.time()
if args.model == "CSDI":
results = model.predict(test_set, n_sampling_times=10)
imputation = results["imputation"].mean(axis=1)
else:
results = model.predict(test_set)
imputation = results["imputation"]

time_collector.append(time.time() - start_time)
mae = calc_mae(imputation, test_X_ori, test_indicating_mask)
mse = calc_mse(imputation, test_X_ori, test_indicating_mask)
mre = calc_mre(imputation, test_X_ori, test_indicating_mask)
Expand All @@ -161,10 +165,12 @@
np.std(mse_collector),
np.std(mre_collector),
)
num_params = sum(p.numel() for p in model.model.parameters() if p.requires_grad)
logger.info(
f"Done! Final results:\n"
f"Averaged {args.model} on {args.dataset}: "
f"Averaged {args.model} (n params: {num_params:,}) on {args.dataset}: "
f"MAE={mean_mae:.4f} ± {std_mae}, "
f"MSE={mean_mse:.4f} ± {std_mse}, "
f"MRE={mean_mre:.4f} ± {std_mre}"
f"MRE={mean_mre:.4f} ± {std_mre}, "
f"average inference time={np.mean(time_collector):.2f}"
)

0 comments on commit 0e36945

Please sign in to comment.