Skip to content

Commit

Permalink
update information.
Browse files Browse the repository at this point in the history
  • Loading branch information
thibault-wch committed Feb 5, 2025
1 parent 07d83d7 commit 7188b13
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 0 deletions.
1 change: 1 addition & 0 deletions Diag_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def evaluate_diagNetwork(model, valid_dataloaders):
val_simple_cnt += labels.size(0)
y_val_true.extend(np.ravel(np.squeeze(labels.cpu().detach().numpy())).tolist())
y_val_pred.extend(np.ravel(np.squeeze(val_predicted.cpu().detach().numpy())).tolist())
outputs=outputs.softmax(dim=-1)
val_prob_all.extend(outputs[:,
1].cpu().detach().numpy())
val_label_all.extend(labels.cpu())
Expand Down
2 changes: 2 additions & 0 deletions utils/Diag_pretraining.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def train_data(model, train_dataloaders, valid_dataloaders, epochs, optimizer, s
train_simple_cnt += labels.size(0)
y_train_true.extend(np.ravel(np.squeeze(labels.cpu().detach().numpy())).tolist())
y_train_pred.extend(np.ravel(np.squeeze(train_predicted.cpu().detach().numpy())).tolist())
outputs = outputs.softmax(dim=-1)
train_prob_all.extend(outputs[:,
1].cpu().detach().numpy())
train_label_all.extend(labels.cpu())
Expand All @@ -80,6 +81,7 @@ def train_data(model, train_dataloaders, valid_dataloaders, epochs, optimizer, s
val_simple_cnt += labels.size(0)
y_val_true.extend(np.ravel(np.squeeze(labels.cpu().detach().numpy())).tolist())
y_val_pred.extend(np.ravel(np.squeeze(val_predicted.cpu().detach().numpy())).tolist())
outputs = outputs.softmax(dim=-1)
val_prob_all.extend(outputs[:,
1].cpu().detach().numpy())
val_label_all.extend(labels.cpu())
Expand Down

0 comments on commit 7188b13

Please sign in to comment.