# Source code for ares.attack.autoattack.autoattack

import math
import time
import numpy as np
import torch
from ares.attack.autoattack import checks
from ares.utils.registry import registry

@registry.register_attack('autoattack')
class AutoAttack():
    '''A class to perform autoattack. It is called by registry.

    Example:
        >>> from ares.utils.registry import registry
        >>> attacker_cls = registry.get_attack('autoattack')
    '''

    def __init__(self, model, device='cuda', norm=np.inf, eps=.3, seed=None, verbose=False,
                 attacks_to_run=None, version='standard', is_tf_model=False, logger=None):
        '''
        Args:
            model (torch.nn.Module): The target model to be attacked.
            device (torch.device): The device to perform autoattack. Defaults to 'cuda'.
            norm (float): The norm of distance calculation for adversarial constraint.
                Defaults to np.inf. It is selected from [1, 2, np.inf].
            eps (float): The maximum perturbation range epsilon.
            seed (float): Random seed. Defaults to None, in which case the current time is
                used whenever an attack is launched (see ``get_seed``).
            verbose (bool): Output the details during the attack process. Defaults to False.
                A logger must be supplied when this is True.
            attacks_to_run (list): Set the attacks to run. Defaults to None, which is treated
                as an empty list. It should be selected from
                ['apgd-ce', 'apgd-dlr', 'fab', 'square', 'apgd-t', 'fab-t'] and may only be
                non-empty when ``version`` is not one of the predefined versions.
            version (str): Define the version of attack. Defaults to 'standard'. The predefined
                versions are ['standard', 'plus', 'rand']; any other value (e.g. 'custom')
                leaves ``attacks_to_run`` as given.
            is_tf_model (bool): Whether the model is based on tensorflow. Defaults to False.
                Only False is supported.
            logger: Logger object forwarded to the sanity checks and to the APGD attacks, and
                used for verbose messages. Defaults to None.

        Raises:
            AssertionError: If ``norm`` is not in [1, 2, np.inf], if ``verbose`` is True while
                ``logger`` is None, or if ``is_tf_model`` is True.
            ValueError: If ``attacks_to_run`` is non-empty together with a predefined version.
        '''
        self.model = model
        self.norm = None
        assert norm in [1, 2, np.inf]
        if norm == 1:
            self.norm = 'L1'
        elif norm == 2:
            self.norm = 'L2'
        else:
            self.norm = 'Linf'
        self.epsilon = eps
        self.seed = seed
        self.verbose = verbose
        # Copy the argument so the instance never shares a list with the caller (or with a
        # default shared between instances).
        self.attacks_to_run = [] if attacks_to_run is None else list(attacks_to_run)
        self.version = version
        self.is_tf_model = is_tf_model
        self.device = device
        self.logger = logger
        if self.verbose:
            assert self.logger is not None, "Must set logger, if verbose is True."
        assert not self.is_tf_model, "Only pytorch models supported."

        if version in ['standard', 'plus', 'rand'] and self.attacks_to_run:
            raise ValueError("attacks_to_run will be overridden unless you use version='custom'")

        # The attack modules are imported here, inside the constructor, exactly as in the
        # original implementation.
        from .autopgd_base import APGDAttack
        self.apgd = APGDAttack(self.model, n_restarts=5, n_iter=100, verbose=False,
                               eps=self.epsilon, norm=self.norm, eot_iter=1, rho=.75,
                               seed=self.seed, device=self.device, logger=self.logger)

        from .fab_pt import FABAttack_PT
        self.fab = FABAttack_PT(self.model, n_restarts=5, n_iter=100, eps=self.epsilon,
                                seed=self.seed, norm=self.norm, verbose=False,
                                device=self.device)

        from .square import SquareAttack
        self.square = SquareAttack(self.model, p_init=.8, n_queries=5000, eps=self.epsilon,
                                   norm=self.norm, n_restarts=1, seed=self.seed, verbose=False,
                                   device=self.device, resc_schedule=False)

        from .autopgd_base import APGDAttack_targeted
        self.apgd_targeted = APGDAttack_targeted(self.model, n_restarts=1, n_iter=100,
                                                 verbose=False, eps=self.epsilon,
                                                 norm=self.norm, eot_iter=1, rho=.75,
                                                 seed=self.seed, device=self.device,
                                                 logger=self.logger)

        if version in ['standard', 'plus', 'rand']:
            self.set_version(version)

    def _log(self, message):
        '''Forward a verbose message to the logger supplied at construction time.'''
        # NOTE(review): assumes the logger exposes a ``logging.Logger``-style ``info``
        # method -- confirm against the logger type used by the callers.
        self.logger.info(message)

    def get_logits(self, x):
        '''This function calculates the logits of the target model.'''
        return self.model(x)

    def get_seed(self):
        '''Return the seed for the next attack run.

        Returns:
            The seed given at construction time, or the current time (``time.time()``) when
            no seed was given.
        '''
        return time.time() if self.seed is None else self.seed

    def __call__(self, images=None, labels=None, target_labels=None):
        '''This function perform attack on target images with corresponding labels.

        Args:
            images (torch.Tensor): The images to be attacked. The images should be torch.Tensor
                with shape [N, C, H, W] and range [0, 1].
            labels (torch.Tensor): The corresponding labels of the images. The labels should be
                torch.Tensor with shape [N, ]
            target_labels (torch.Tensor): Not used in autoattack and should be None type.

        Returns:
            torch.Tensor: The adversarial images; entries for which no attack succeeded are
            copies of the corresponding input images.

        Raises:
            AssertionError: If ``target_labels`` is not None.
        '''
        assert target_labels is None, "Target attack is not necessary for autoattack."
        # The whole input is processed as a single batch.
        x_adv = self.run_standard_evaluation(images, labels, bs=images.size(0),
                                             return_labels=False)
        return x_adv

    def _run_attack(self, attack, x, y):
        '''Run the single attack named ``attack`` on the batch ``(x, y)``.

        Raises:
            ValueError: If ``attack`` is not one of the supported attack names.
        '''
        if attack == 'apgd-ce':
            # apgd on cross-entropy loss
            self.apgd.loss = 'ce'
            self.apgd.seed = self.get_seed()
            return self.apgd.perturb(x, y)
        if attack == 'apgd-dlr':
            # apgd on dlr loss
            self.apgd.loss = 'dlr'
            self.apgd.seed = self.get_seed()
            return self.apgd.perturb(x, y)
        if attack == 'fab':
            self.fab.targeted = False
            self.fab.seed = self.get_seed()
            return self.fab.perturb(x, y)
        if attack == 'square':
            self.square.seed = self.get_seed()
            return self.square.perturb(x, y)
        if attack == 'apgd-t':
            # targeted apgd
            self.apgd_targeted.seed = self.get_seed()
            return self.apgd_targeted.perturb(x, y)
        if attack == 'fab-t':
            # targeted fab; note that n_restarts stays at 1 on the shared FAB instance
            self.fab.targeted = True
            self.fab.n_restarts = 1
            self.fab.seed = self.get_seed()
            return self.fab.perturb(x, y)
        raise ValueError('Attack not supported')

    def run_standard_evaluation(self, images, labels, bs=250, return_labels=False):
        '''Run every attack in ``self.attacks_to_run`` in sequence.

        Each attack is only applied to the samples that are still classified correctly after
        the previous attacks.

        Args:
            images (torch.Tensor): The images to be attacked.
            labels (torch.Tensor): The corresponding labels of the images.
            bs (int): Batch size. Defaults to 250.
            return_labels (bool): Whether to also return the predicted labels of the returned
                images. Defaults to False.

        Returns:
            ``x_adv`` if ``return_labels`` is False, otherwise the tuple ``(x_adv, y_adv)``.

        Raises:
            ValueError: If ``self.attacks_to_run`` contains an unsupported attack name.
        '''
        if self.verbose:
            print('using {} version including {}'.format(self.version,
                                                         ', '.join(self.attacks_to_run)))

        # checks on type of defense
        if self.version != 'rand':
            checks.check_randomized(self.get_logits, images[:bs].to(self.device),
                                    labels[:bs].to(self.device), bs=bs, logger=self.logger)
        n_cls = checks.check_range_output(self.get_logits, images[:bs].to(self.device),
                                          logger=self.logger)
        checks.check_dynamic(self.model, images[:bs].to(self.device), self.is_tf_model,
                             logger=self.logger)
        checks.check_n_classes(n_cls, self.attacks_to_run, self.apgd_targeted.n_target_classes,
                               self.fab.n_target_classes, logger=self.logger)

        n_images = images.shape[0]
        # NOTE(review): the attacks are run inside ``no_grad``; this presumes they re-enable
        # gradients internally where needed -- confirm against the attack modules.
        with torch.no_grad():
            # calculate accuracy on the unperturbed images
            n_batches = int(np.ceil(n_images / bs))
            robust_flags = torch.zeros(n_images, dtype=torch.bool, device=images.device)
            y_adv = torch.empty_like(labels)
            for batch_idx in range(n_batches):
                start_idx = batch_idx * bs
                end_idx = min((batch_idx + 1) * bs, n_images)

                x = images[start_idx:end_idx, :].clone().to(self.device)
                y = labels[start_idx:end_idx].clone().to(self.device)
                output = self.get_logits(x).max(dim=1)[1]
                y_adv[start_idx:end_idx] = output
                correct_batch = y.eq(output)
                robust_flags[start_idx:end_idx] = correct_batch.detach().to(robust_flags.device)

            robust_accuracy = torch.sum(robust_flags).item() / n_images
            robust_accuracy_dict = {'clean': robust_accuracy}
            if self.verbose:
                self._log('initial accuracy: {:.2%}'.format(robust_accuracy))

            x_adv = images.clone().detach()
            startt = time.time()
            for attack in self.attacks_to_run:
                # item() is super important as pytorch int division uses floor rounding
                num_robust = torch.sum(robust_flags).item()
                if num_robust == 0:
                    break

                n_batches = int(np.ceil(num_robust / bs))
                # nonzero returns shape (num_robust, 1); drop the trailing axis so that batch
                # slices are always 1-D index tensors, even for a single remaining sample.
                robust_lin_idcs = torch.nonzero(robust_flags, as_tuple=False).squeeze(-1)

                for batch_idx in range(n_batches):
                    start_idx = batch_idx * bs
                    end_idx = min((batch_idx + 1) * bs, num_robust)

                    batch_datapoint_idcs = robust_lin_idcs[start_idx:end_idx]
                    x = images[batch_datapoint_idcs, :].clone().to(self.device)
                    y = labels[batch_datapoint_idcs].clone().to(self.device)

                    adv_curr = self._run_attack(attack, x, y)

                    output = self.get_logits(adv_curr).max(dim=1)[1]
                    false_batch = ~y.eq(output).to(robust_flags.device)
                    non_robust_lin_idcs = batch_datapoint_idcs[false_batch]
                    robust_flags[non_robust_lin_idcs] = False

                    x_adv[non_robust_lin_idcs] = adv_curr[false_batch].detach().to(x_adv.device)
                    y_adv[non_robust_lin_idcs] = output[false_batch].detach().to(y_adv.device)

                    if self.verbose:
                        num_non_robust_batch = torch.sum(false_batch)
                        self._log('{} - {}/{} - {} out of {} successfully perturbed'.format(
                            attack, batch_idx + 1, n_batches, num_non_robust_batch, x.shape[0]))

                robust_accuracy = torch.sum(robust_flags).item() / n_images
                robust_accuracy_dict[attack] = robust_accuracy
                if self.verbose:
                    self._log('robust accuracy after {}: {:.2%} (total time {:.1f} s)'.format(
                        attack.upper(), robust_accuracy, time.time() - startt))

            # check about square
            checks.check_square_sr(robust_accuracy_dict, logger=self.logger)

            # final check: report the size of the largest perturbation actually produced
            if self.verbose:
                if self.norm == 'Linf':
                    res = (x_adv - images).abs().reshape(n_images, -1).max(1)[0]
                elif self.norm == 'L2':
                    res = ((x_adv - images) ** 2).reshape(n_images, -1).sum(-1).sqrt()
                elif self.norm == 'L1':
                    res = (x_adv - images).abs().reshape(n_images, -1).sum(dim=-1)
                self._log('max {} perturbation: {:.5f}, nan in tensor: {}, max: {:.5f}, min: {:.5f}'.format(
                    self.norm, res.max(), (x_adv != x_adv).sum(), x_adv.max(), x_adv.min()))
                self._log('robust accuracy: {:.2%}'.format(robust_accuracy))

        if return_labels:
            return x_adv, y_adv
        return x_adv

    def clean_accuracy(self, images, labels, bs=250):
        '''Compute the accuracy of the target model on the given images.

        Args:
            images (torch.Tensor): The images to classify. Must be non-empty.
            labels (torch.Tensor): The corresponding labels of the images.
            bs (int): Batch size. Defaults to 250.

        Returns:
            float: The fraction of images whose predicted class equals the label.
        '''
        n_images = images.shape[0]
        n_batches = math.ceil(n_images / bs)
        # Count with a Python int so the result does not depend on a tensor ``.item()`` call.
        n_correct = 0
        with torch.no_grad():
            for counter in range(n_batches):
                start_idx = counter * bs
                end_idx = min((counter + 1) * bs, n_images)
                x = images[start_idx:end_idx].clone().to(self.device)
                y = labels[start_idx:end_idx].clone().to(self.device)
                output = self.get_logits(x)
                n_correct += (output.max(1)[1] == y).sum().item()

        accuracy = n_correct / n_images
        if self.verbose:
            print('clean accuracy: {:.2%}'.format(accuracy))

        return accuracy

    def run_standard_evaluation_individual(self, images, labels, bs=250, return_labels=False):
        '''Run every attack in ``self.attacks_to_run`` separately on the full input.

        Args:
            images (torch.Tensor): The images to be attacked.
            labels (torch.Tensor): The corresponding labels of the images.
            bs (int): Batch size. Defaults to 250.
            return_labels (bool): Whether to store the predicted labels next to the
                adversarial images. Defaults to False.

        Returns:
            dict: Maps each attack name to ``x_adv``, or to ``(x_adv, y_adv)`` when
            ``return_labels`` is True.
        '''
        if self.verbose:
            print('using {} version including {}'.format(self.version,
                                                         ', '.join(self.attacks_to_run)))

        l_attacks = self.attacks_to_run
        adv = {}
        verbose_indiv = self.verbose
        # Silence the per-batch output of the inner evaluation; only a summary line per attack
        # is emitted below.
        self.verbose = False

        try:
            for c in l_attacks:
                startt = time.time()
                self.attacks_to_run = [c]
                x_adv, y_adv = self.run_standard_evaluation(images, labels, bs=bs,
                                                            return_labels=True)
                if return_labels:
                    adv[c] = (x_adv, y_adv)
                else:
                    adv[c] = x_adv
                if verbose_indiv:
                    acc_indiv = self.clean_accuracy(x_adv, labels, bs=bs)
                    space = '\t \t' if c == 'fab' else '\t'
                    self._log('robust accuracy by {} {} {:.2%} \t (time attack: {:.1f} s)'.format(
                        c.upper(), space, acc_indiv, time.time() - startt))
        finally:
            # Restore the state that was overridden above, even if an attack raised.
            self.attacks_to_run = l_attacks
            self.verbose = verbose_indiv

        return adv

    def set_version(self, version='standard'):
        '''The function to set the attack version.

        Versions other than 'standard', 'plus' and 'rand' leave the configuration untouched.

        Args:
            version (str): The version of attack. Defaults to 'standard'.
        '''
        if self.verbose:
            print('setting parameters for {} version'.format(version))

        if version == 'standard':
            self.attacks_to_run = ['apgd-ce', 'apgd-t', 'fab-t', 'square']
            if self.norm in ['Linf', 'L2']:
                self.apgd.n_restarts = 1
                self.apgd_targeted.n_target_classes = 9
            elif self.norm in ['L1']:
                self.apgd.use_largereps = True
                self.apgd_targeted.use_largereps = True
                self.apgd.n_restarts = 5
                self.apgd_targeted.n_target_classes = 5
            self.fab.n_restarts = 1
            self.apgd_targeted.n_restarts = 1
            self.fab.n_target_classes = 9
            self.square.n_queries = 5000

        elif version == 'plus':
            self.attacks_to_run = ['apgd-ce', 'apgd-dlr', 'fab', 'square', 'apgd-t', 'fab-t']
            self.apgd.n_restarts = 5
            self.fab.n_restarts = 5
            self.apgd_targeted.n_restarts = 1
            self.fab.n_target_classes = 9
            self.apgd_targeted.n_target_classes = 9
            self.square.n_queries = 5000
            if self.norm not in ['Linf', 'L2']:
                print('"{}" version is used with {} norm: please check'.format(
                    version, self.norm))

        elif version == 'rand':
            self.attacks_to_run = ['apgd-ce', 'apgd-dlr']
            self.apgd.n_restarts = 1
            self.apgd.eot_iter = 20

        # Keep self.version in step with the configuration that was just applied, because
        # run_standard_evaluation consults it to decide whether to run the randomness check.
        if version in ['standard', 'plus', 'rand']:
            self.version = version