Skip to content

Commit

Permalink
Perf attack save
Browse files Browse the repository at this point in the history
  • Loading branch information
klchen0112 committed Dec 15, 2023
1 parent 936e86d commit c8bd563
Showing 1 changed file with 76 additions and 35 deletions.
111 changes: 76 additions & 35 deletions torchattacks/attack.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,7 @@ def save(
self,
data_loader,
save_path=None,
save_every_iter=False,
verbose=True,
return_verbose=False,
save_predictions=False,
Expand All @@ -260,6 +261,7 @@ def save(
Arguments:
save_path (str): save_path.
data_loader (torch.utils.data.DataLoader): data loader.
save_every_iter (bool): True for save every results every iter. (Default: False)
verbose (bool): True for displaying detailed information. (Default: True)
return_verbose (bool): True for returning detailed information. (Default: False)
save_predictions (bool): True for saving predicted labels (Default: False)
Expand Down Expand Up @@ -319,45 +321,33 @@ def save(
if save_path is not None:
adv_input_list.append(adv_inputs.detach().cpu())
label_list.append(labels.detach().cpu())

adv_input_list_cat = torch.cat(adv_input_list, 0)
label_list_cat = torch.cat(label_list, 0)

save_dict = {
"adv_inputs": adv_input_list_cat,
"labels": label_list_cat,
} # nopep8

if save_predictions:
pred_list.append(pred.detach().cpu())
pred_list_cat = torch.cat(pred_list, 0)
save_dict["preds"] = pred_list_cat

if save_clean_inputs:
input_list.append(inputs.detach().cpu())
input_list_cat = torch.cat(input_list, 0)
save_dict["clean_inputs"] = input_list_cat

if self.normalization_used is not None:
save_dict["adv_inputs"] = self.inverse_normalize(
save_dict["adv_inputs"]
) # nopep8
if save_clean_inputs:
save_dict["clean_inputs"] = self.inverse_normalize(
save_dict["clean_inputs"]
) # nopep8

if save_type == "int":
save_dict["adv_inputs"] = self.to_type(
save_dict["adv_inputs"], "int"
) # nopep8
if save_clean_inputs:
save_dict["clean_inputs"] = self.to_type(
save_dict["clean_inputs"], "int"
) # nopep8

save_dict["save_type"] = save_type
torch.save(save_dict, save_path)
if save_every_iter:
self._save_adv_examples(
save_type,
save_path,
adv_input_list,
label_list,
save_predictions = save_predictions,
pred_list = pred_list if save_predictions else None,
save_clean_inputs = save_clean_inputs,
input_list = input_list if save_clean_inputs else None,
)

if save_path is not None and not save_every_iter:
self._save_adv_examples(
save_type,
save_path,
adv_input_list,
label_list,
save_predictions = save_predictions,
pred_list = pred_list if save_predictions else None,
save_clean_inputs = save_clean_inputs,
input_list = input_list if save_clean_inputs else None,
)

# To avoid erasing the printed information.
if verbose:
Expand Down Expand Up @@ -388,6 +378,57 @@ def to_type(inputs, type):
raise ValueError(type + " is not a valid type. [Options: float, int]")
return inputs


def _save_adv_examples(
self,
save_type,
save_path,
adv_input_list,
label_list,
save_predictions = False,
pred_list = [],
save_clean_inputs = False,
input_list = [],
):


adv_input_list_cat = torch.cat(adv_input_list, 0)
label_list_cat = torch.cat(label_list, 0)

save_dict = {
"adv_inputs": adv_input_list_cat,
"labels": label_list_cat,
}

if save_predictions:
pred_list_cat = torch.cat(pred_list, 0)
save_dict["preds"] = pred_list_cat

if save_clean_inputs:
input_list_cat = torch.cat(input_list, 0)
save_dict["clean_inputs"] = input_list_cat

if self.normalization_used is not None:
save_dict["adv_inputs"] = self.inverse_normalize(
save_dict["adv_inputs"]
) # nopep8
if save_clean_inputs:
save_dict["clean_inputs"] = self.inverse_normalize(
save_dict["clean_inputs"]
) # nopep8

if save_type == "int":
save_dict["adv_inputs"] = self.to_type(
save_dict["adv_inputs"], "int"
) # nopep8
if save_clean_inputs:
save_dict["clean_inputs"] = self.to_type(
save_dict["clean_inputs"], "int"
) # nopep8

save_dict["save_type"] = save_type
torch.save(save_dict, save_path)

@staticmethod
def _save_print(progress, rob_acc, l2, elapsed_time, end):
print(
Expand Down

0 comments on commit c8bd563

Please sign in to comment.