Source code for ares.attack.nes

import torch
import numpy as np
from ares.utils.registry import registry

@registry.register_attack('nes')
class NES(object):
    '''Natural Evolution Strategies (NES).

    A black-box constraint-based method. NES is used as a gradient estimation
    technique, and PGD with this estimated gradient generates the adversarial
    example.

    Example:
        >>> from ares.utils.registry import registry
        >>> attacker_cls = registry.get_attack('nes')
        >>> attacker = attacker_cls(model)
        >>> adv_images = attacker(images, labels, target_labels)

    - Supported distance metric: 1, 2, np.inf.
    - References:
      1. https://arxiv.org/abs/1804.08598.
      2. http://www.jmlr.org/papers/volume15/wierstra14a/wierstra14a.pdf.
    '''

    def __init__(self, model, device='cuda', norm=np.inf, eps=4/255, stepsize=1/255, nes_samples=10,
                 sample_per_draw=1, max_queries=1000, search_sigma=0.02, decay=0.00,
                 random_perturb_start=False, target=False):
        '''The initialize function for NES.

        Args:
            model (torch.nn.Module): The target model to be attacked.
            device (torch.device): The device to run the attack on. Defaults to 'cuda'.
            norm (float): The norm of distance calculation for adversarial constraint
                (np.inf, 1 or 2). Defaults to np.inf.
            eps (float): The maximum perturbation range epsilon.
            stepsize (float): The step size for each attack iteration. Defaults to 1/255.
            nes_samples (int): Model queries per NES round. Rounded down to an even number
                (at least 2) because samples are drawn as antithetic (+u, -u) pairs.
                Falls back to ``sample_per_draw`` when falsy.
            sample_per_draw (int): Model queries spent per gradient estimate. Rounded down
                to a multiple of ``nes_samples`` (at least ``nes_samples``).
            max_queries (int): Query budget per image for gradient estimation (the
                success checks are not counted).
            search_sigma (float): Standard deviation of the search perturbation.
            decay (float): Momentum decay rate.
            random_perturb_start (bool): Whether to start from a random point in the eps-ball.
            target (bool): Whether to conduct a targeted attack. Defaults to False.
        '''
        self.model = model
        self.p = norm
        self.epsilon = eps
        self.step_size = stepsize
        self.max_queries = max_queries
        self.device = device
        self.search_sigma = search_sigma
        nes_samples = nes_samples if nes_samples else sample_per_draw
        # Samples come in antithetic pairs, so use an even count. Keep at least one pair;
        # zero samples would make the estimator divide by zero.
        self.nes_samples = max(2, (nes_samples // 2) * 2)
        # At least one round per gradient estimate. Without this clamp, the defaults
        # (nes_samples=10, sample_per_draw=1) gave sample_per_draw == 0: no query was
        # ever charged and the attack loop could never terminate.
        self.nes_iters = max(1, sample_per_draw // self.nes_samples)
        self.sample_per_draw = self.nes_iters * self.nes_samples
        self.decay = decay
        self.random_perturb_start = random_perturb_start
        self.target = target
        self.min_value = 0
        self.max_value = 1
        # Per-call statistics ('success', 'queries'), filled in by `nes`.
        self.detail = {}

    @staticmethod
    def _per_sample(values, ref):
        '''Reshape a per-sample tensor of shape [N] so it broadcasts over ``ref`` ([N, ...]).'''
        return values.view(-1, *([1] * (ref.dim() - 1)))

    def _is_adversarial(self, x, y, y_target):
        '''Return a per-sample boolean tensor telling whether ``x`` is adversarial.

        Untargeted: the predicted class differs from ``y``.
        Targeted: the predicted class equals ``y_target``.
        '''
        output = torch.argmax(self.model(x), dim=1)
        if self.target:
            return output == y_target
        return output != y

    def _margin_logit_loss(self, x, labels, target_labels):
        '''Compute the per-sample margin loss the attack maximises.

        Let ``j`` be the logit of the reference class (``labels`` when untargeted,
        ``target_labels`` when targeted) and ``i`` the largest logit among all other
        classes. The loss is ``-max(0, j - i)`` (untargeted) or ``-max(0, i - j)``
        (targeted). It reaches its maximum 0 once the objective is met (kappa = 0).

        Returns:
            torch.Tensor: Loss of shape [N].
        '''
        outputs = self.model(x)
        ref_labels = (target_labels if self.target else labels).to(outputs.device)
        # Build the mask on the logits' device: indexing a CPU eye() with CUDA labels fails.
        ref_mask = torch.eye(outputs.shape[1], device=outputs.device)[ref_labels].bool()
        j = torch.masked_select(outputs, ref_mask)
        # Exclude the reference class with -inf rather than multiplying it by zero. A zero
        # would wrongly win the max whenever all other logits are negative.
        i, _ = torch.max(outputs.masked_fill(ref_mask, float('-inf')), dim=1)
        if self.target:
            return -torch.clamp(i - j, min=0)
        return -torch.clamp(j - i, min=0)

    def clip_eta(self, batchsize, eta, norm, eps):
        '''Clip the perturbation ``eta`` into the ``norm`` ball of radius ``eps``.

        Args:
            batchsize (int): Number of samples (first dimension of ``eta``).
            eta (torch.Tensor): Perturbation of shape [batchsize, ...].
            norm (float): np.inf, 1 or 2.
            eps (float): Radius of the ball.

        Returns:
            torch.Tensor: Clipped perturbation with the same shape as ``eta``. For np.inf
            every entry is clamped to [-eps, eps]. For the 1- and 2-norms each sample whose
            norm exceeds ``eps`` is rescaled onto the sphere; samples inside are unchanged.

        Raises:
            NotImplementedError: For any other norm.
        '''
        if norm == np.inf:
            return torch.clamp(eta, -eps, eps)
        if norm in (1, 2):
            norm_val = torch.norm(eta.reshape(batchsize, -1), norm, 1)
            scaling = torch.where(norm_val <= eps, torch.ones_like(norm_val), eps / norm_val)
            return eta * self._per_sample(scaling, eta)
        raise NotImplementedError

    def nes_gradient(self, x, y, ytarget):
        '''Estimate the gradient of the margin loss w.r.t. ``x`` with antithetic NES sampling.

        Each of the ``self.nes_iters`` rounds draws ``self.nes_samples // 2`` standard
        Gaussian directions ``u``. For each direction, the loss is queried at
        ``x + sigma * u`` and ``x - sigma * u`` (both clipped to the valid pixel range).
        This costs exactly ``self.sample_per_draw`` model queries.

        Returns:
            torch.Tensor: Estimated gradient with the same shape as ``x``.
        '''
        num_pairs = self.nes_iters * (self.nes_samples // 2)
        with torch.no_grad():
            grad = torch.zeros_like(x)
            for _ in range(num_pairs):
                u = torch.randn_like(x)
                loss_pos = self._margin_logit_loss(
                    torch.clamp(x + self.search_sigma * u, self.min_value, self.max_value), y, ytarget)
                loss_neg = self._margin_logit_loss(
                    torch.clamp(x - self.search_sigma * u, self.min_value, self.max_value), y, ytarget)
                # Weight each direction by its own sample's loss difference (batch axis, not W).
                grad = grad + self._per_sample(loss_pos - loss_neg, x) * u
            return grad / (2 * num_pairs * self.search_sigma)

    def nes(self, x_victim, y_victim, y_target):
        '''Run the NES attack on one batch (``__call__`` passes one image at a time).

        Sets ``self.detail['success']`` (whether every sample became adversarial) and
        ``self.detail['queries']`` (gradient-estimation queries spent).

        Returns:
            torch.Tensor: Adversarial images within the eps-ball of ``x_victim`` and
            clipped to [min_value, max_value].

        Raises:
            ValueError: If the attack is targeted but ``y_target`` is None.
        '''
        if self.target and y_target is None:
            raise ValueError('target_labels must be provided for a targeted attack.')
        batchsize = x_victim.shape[0]
        with torch.no_grad():
            self.model.to(self.device)
            self.model.eval()
            x_victim = x_victim.to(self.device)
            y_victim = y_victim.to(self.device)
            if y_target is not None:
                y_target = y_target.to(self.device)
            if self._is_adversarial(x_victim, y_victim, y_target).all():
                self.detail['queries'] = 0
                self.detail['success'] = True
                return x_victim
            self.detail['success'] = False

            x_adv = x_victim.clone()
            if self.random_perturb_start:
                # Symmetric uniform noise, projected into the ball and the valid pixel range.
                noise = torch.empty_like(x_adv).uniform_(-self.epsilon, self.epsilon)
                noise = self.clip_eta(batchsize, noise, self.p, self.epsilon)
                x_adv = torch.clamp(x_adv + noise, self.min_value, self.max_value)

            momentum = torch.zeros_like(x_adv)
            queries = 0
            while queries + self.sample_per_draw <= self.max_queries:
                queries += self.sample_per_draw
                grad = self.nes_gradient(x_adv, y_victim, y_target)
                grad = grad + momentum * self.decay
                momentum = grad
                if self.p == np.inf:
                    updates = grad.sign()
                else:
                    grad_norm = torch.norm(grad.reshape(batchsize, -1), self.p, 1)
                    # A flat loss gives a zero gradient; avoid 0/0 = NaN in the update.
                    updates = grad / self._per_sample(grad_norm.clamp_min(1e-12), grad)
                x_adv = x_adv + updates * self.step_size
                delta = self.clip_eta(batchsize, x_adv - x_victim, self.p, self.epsilon)
                x_adv = torch.clamp(x_victim + delta, min=self.min_value, max=self.max_value).detach()
                if self._is_adversarial(x_adv, y_victim, y_target).all():
                    self.detail['success'] = True
                    break
            self.detail['queries'] = queries
        return x_adv

    def __call__(self, images=None, labels=None, target_labels=None):
        '''Attack the images one by one.

        Args:
            images (torch.Tensor): The images to be attacked. The images should be
                torch.Tensor with shape [N, C, H, W] and range [0, 1].
            labels (torch.Tensor): The corresponding labels of the images. The labels
                should be torch.Tensor with shape [N, ].
            target_labels (torch.Tensor): The target labels for target attack. The labels
                should be torch.Tensor with shape [N, ].

        Returns:
            torch.Tensor: Adversarial images with value range [0,1].

        Note:
            ``self.detail`` is reset here and then overwritten per image, so after the
            call it describes only the last image.
        '''
        self.detail = {}
        adv_images = []
        for idx, image in enumerate(images):
            target_label = None if target_labels is None else target_labels[idx].unsqueeze(0)
            adv_images.append(self.nes(image.unsqueeze(0), labels[idx].unsqueeze(0), target_label))
        return torch.cat(adv_images, 0)