diff --git a/Diag_test.py b/Diag_test.py index 803b4f3..5caade4 100644 --- a/Diag_test.py +++ b/Diag_test.py @@ -42,6 +42,7 @@ def evaluate_diagNetwork(model, valid_dataloaders): val_simple_cnt += labels.size(0) y_val_true.extend(np.ravel(np.squeeze(labels.cpu().detach().numpy())).tolist()) y_val_pred.extend(np.ravel(np.squeeze(val_predicted.cpu().detach().numpy())).tolist()) + outputs=outputs.softmax(dim=-1) val_prob_all.extend(outputs[:, 1].cpu().detach().numpy()) val_label_all.extend(labels.cpu()) diff --git a/utils/Diag_pretraining.py b/utils/Diag_pretraining.py index 4c2e790..5c00856 100644 --- a/utils/Diag_pretraining.py +++ b/utils/Diag_pretraining.py @@ -57,6 +57,7 @@ def train_data(model, train_dataloaders, valid_dataloaders, epochs, optimizer, s train_simple_cnt += labels.size(0) y_train_true.extend(np.ravel(np.squeeze(labels.cpu().detach().numpy())).tolist()) y_train_pred.extend(np.ravel(np.squeeze(train_predicted.cpu().detach().numpy())).tolist()) + outputs = outputs.softmax(dim=-1) train_prob_all.extend(outputs[:, 1].cpu().detach().numpy()) train_label_all.extend(labels.cpu()) @@ -80,6 +81,7 @@ def train_data(model, train_dataloaders, valid_dataloaders, epochs, optimizer, s val_simple_cnt += labels.size(0) y_val_true.extend(np.ravel(np.squeeze(labels.cpu().detach().numpy())).tolist()) y_val_pred.extend(np.ravel(np.squeeze(val_predicted.cpu().detach().numpy())).tolist()) + outputs = outputs.softmax(dim=-1) val_prob_all.extend(outputs[:, 1].cpu().detach().numpy()) val_label_all.extend(labels.cpu())