"""Source code for ares.attack.cw."""

import numpy as np
import torch
from torch.autograd import Variable
from ares.utils.registry import registry

@registry.register_attack('cw')
class CW(object):
    '''Carlini & Wagner Attack (C&W).

    A white-box iterative optimization-based method. Requires differentiable logits.

    Example:
        >>> from ares.utils.registry import registry
        >>> attacker_cls = registry.get_attack('cw')
        >>> attacker = attacker_cls(model)
        >>> adv_images = attacker(images, labels, target_labels)

    - Supported distance metric: 2.
    - References: https://arxiv.org/pdf/1608.04644.pdf.
    '''

    def __init__(self, model, device='cuda', norm=2, kappa=0, lr=0.2, init_const=0.01,
                 max_iter=200, binary_search_steps=4, num_classes=1000, target=False):
        '''
        Args:
            model (torch.nn.Module): The target model to be attacked.
            device (torch.device): The device to perform the attack. Defaults to 'cuda'.
            norm (float): The norm of distance calculation for adversarial constraint. Defaults to 2.
            kappa (float): Confidence margin of the C&W loss. Defaults to 0.
            lr (float): The learning rate for the attack process.
            init_const (float): The initial trade-off constant c.
            max_iter (int): The maximum number of optimization iterations per binary-search step.
            binary_search_steps (int): The number of binary-search steps over the constant c.
            num_classes (int): The number of classes of all the labels.
            target (bool): Conduct targeted/untargeted attack. Defaults to False.
        '''
        self.net = model
        self.device = device
        self.IsTargeted = target
        self.kappa = kappa
        self.learning_rate = lr
        self.init_const = init_const
        # Valid pixel range of the input images.
        self.lower_bound = 0.0
        self.upper_bound = 1.0
        self.max_iter = max_iter
        self.norm = norm
        self.binary_search_steps = binary_search_steps
        self.class_type_number = num_classes
        # Fixed typo in the assertion message ('curreent' -> 'current').
        assert self.norm == 2, 'current cw only supports l_2'

    def atanh(self, x):
        '''Inverse hyperbolic tangent, used to map pixel space into tanh space.'''
        return 0.5 * torch.log((1 + x) / (1 - x))

    def __call__(self, images=None, labels=None, target_labels=None):
        '''Perform the attack on target images.

        Args:
            images (torch.Tensor): The images to be attacked, shape [N, C, H, W], range [0, 1].
            labels (torch.Tensor): The corresponding ground-truth labels, shape [N, ].
            target_labels (torch.Tensor): The target labels for targeted attack, shape [N, ].

        Returns:
            torch.Tensor: Adversarial images with value range [0, 1].
        '''
        device = self.device
        targeted = self.IsTargeted
        copy_images = images.clone()
        copy_labels = labels.clone()
        if target_labels is not None:
            copy_target_labels = target_labels.clone()
        else:
            # Untargeted attack: reuse the true labels as placeholders.
            copy_target_labels = copy_labels

        batch_size = images.shape[0]
        # Change-of-variable into tanh space; the 0.9999 factor keeps atanh finite at the
        # pixel-range boundaries.
        mid_point = (self.upper_bound + self.lower_bound) * 0.5
        half_range = (self.upper_bound - self.lower_bound) * 0.5
        arctanh_images = self.atanh((copy_images - mid_point) / half_range * 0.9999)
        var_images = arctanh_images.clone()
        var_images.requires_grad = True

        # Per-sample trade-off constant c and its binary-search bounds.
        const_origin = torch.ones(batch_size, device=device) * self.init_const
        c_upper_bound = [1e10] * batch_size
        c_lower_bound = torch.zeros(batch_size, device=device)

        # One-hot encoding of the label each sample is optimized against.
        one_hot_eye = torch.eye(int(self.class_type_number), device=device)
        if targeted:
            label_one_hot = one_hot_eye[copy_target_labels].clone().type_as(images).to(device)
        else:
            label_one_hot = one_hot_eye[copy_labels].clone().type_as(images).to(device)

        best_l2 = [1e10] * batch_size
        # Bug fix: initialize with the clean images instead of zeros, so samples that are
        # never successfully attacked fall back to the original image rather than black.
        best_perturbation = copy_images.clone().detach().to(device)
        current_prediction_class = [-1] * batch_size

        def attack_achieved(pre_softmax, true_class, target_class):
            # `pre_softmax` is a detached copy, so the in-place kappa margin below cannot
            # corrupt the retained autograd graph of `prediction` (the original mutated
            # the live graph tensor).
            if targeted:
                pre_softmax[target_class] -= self.kappa
                return torch.argmax(pre_softmax).item() == target_class
            pre_softmax[true_class] -= self.kappa
            return torch.argmax(pre_softmax).item() != true_class

        for search_for_c in range(self.binary_search_steps):
            # Fresh perturbation variable and optimizer for each value of c.
            # (Replaces the deprecated torch.autograd.Variable wrapper.)
            modifier = torch.zeros(var_images.shape, dtype=torch.float32,
                                   device=device, requires_grad=True)
            optimizer = torch.optim.Adam([modifier], lr=self.learning_rate)
            var_const = const_origin.clone().to(device)

            for iteration_times in range(self.max_iter):
                # Map tanh space back to pixel space [lower_bound, upper_bound].
                perturbed_images = torch.tanh(var_images + modifier) * half_range + mid_point
                prediction = self.net(perturbed_images)

                # Squared L2 distortion per sample.
                l2dist = torch.sum(
                    (perturbed_images - (torch.tanh(var_images) * half_range + mid_point)) ** 2,
                    [1, 2, 3],
                )

                # C&W margin loss, floored at -kappa. The 1e10 mask removes the (target)
                # class logit from the max over the remaining classes.
                kappa_floor = torch.ones(batch_size, device=device) * self.kappa * -1
                if targeted:
                    constraint_loss = torch.max(
                        (prediction - 1e10 * label_one_hot).max(1)[0]
                        - (prediction * label_one_hot).sum(1),
                        kappa_floor,
                    )
                else:
                    constraint_loss = torch.max(
                        (prediction * label_one_hot).sum(1)
                        - (prediction - 1e10 * label_one_hot).max(1)[0],
                        kappa_floor,
                    )

                loss_f = var_const * constraint_loss
                # minimize |r| + c * loss_f(x + r, l)
                loss = l2dist.sum() + loss_f.sum()
                optimizer.zero_grad()
                loss.backward(retain_graph=True)
                optimizer.step()

                # Track the smallest-distortion successful adversarial per sample.
                for i in range(batch_size):
                    dist = l2dist[i].item()
                    score = prediction[i].detach().clone()
                    if dist < best_l2[i] and attack_achieved(score, copy_labels[i],
                                                             copy_target_labels[i]):
                        best_l2[i] = dist
                        current_prediction_class[i] = torch.argmax(score).item()
                        best_perturbation[i] = perturbed_images[i].detach()

            # Binary search update of the constant c for each sample in the batch.
            for i in range(batch_size):
                if targeted:
                    succeeded = current_prediction_class[i] == copy_target_labels[i].item()
                else:
                    # Bug fix: untargeted success means the prediction moved AWAY from the
                    # true label; the original tested for equality, so a successful attack
                    # never shrank c.
                    succeeded = current_prediction_class[i] != copy_labels[i].item()
                if succeeded and current_prediction_class[i] != -1:
                    # Attack succeeded with this c: binary-search downward.
                    c_upper_bound[i] = min(c_upper_bound[i], const_origin[i].item())
                    if c_upper_bound[i] < 1e10:
                        const_origin[i] = (c_lower_bound[i].item() + c_upper_bound[i]) / 2.0
                else:
                    # Attack failed: raise the lower bound and grow c.
                    c_lower_bound[i] = max(c_lower_bound[i].item(), const_origin[i].item())
                    if c_upper_bound[i] < 1e10:
                        # Bug fix: the original assigned to `const_origin` (the whole
                        # tensor, replaced by a Python float) instead of `const_origin[i]`.
                        const_origin[i] = (c_lower_bound[i].item() + c_upper_bound[i]) / 2.0
                    else:
                        const_origin[i] *= 10

        adv_images = best_perturbation.to(device)
        return adv_images