# Source code for ares.attack.spsa

import numpy as np
import torch
from torch import optim
from ares.utils.registry import registry

@registry.register_attack('spsa')
class SPSA(object):
    '''Simultaneous Perturbation Stochastic Approximation (SPSA).

    A black-box constraint-based method. Uses SPSA as the gradient-estimation
    technique and employs Adam with this estimated gradient to generate the
    adversarial example.

    Example:
        >>> from ares.utils.registry import registry
        >>> attacker_cls = registry.get_attack('spsa')
        >>> attacker = attacker_cls(model)
        >>> adv_images = attacker(images, labels, target_labels)

    - Supported distance metric: 2, np.inf (``clip_eta`` raises
      ``NotImplementedError`` for any other norm, including 1).
    - References: https://arxiv.org/abs/1802.05666.
    '''

    def __init__(self, model, device='cuda', norm=np.inf, eps=4/255, learning_rate=0.01, delta=0.01,
                 spsa_samples=10, sample_per_draw=1, nb_iter=20, early_stop_loss_threshold=None,
                 target=False):
        '''The initialize function for SPSA.

        Args:
            model (torch.nn.Module): The target model to be attacked.
            device (torch.device): The device to perform autoattack. Defaults to 'cuda'.
            norm (float): The norm of distance calculation for adversarial constraint. Defaults to np.inf.
            eps (float): The maximum perturbation range epsilon.
            learning_rate (float): The learning rate of attack.
            delta (float): The delta param (finite-difference step for SPSA).
            spsa_samples (int): Number of samples in SPSA; rounded down to an even
                number (antithetic +/- pairs). Must be >= 2 after rounding.
            sample_per_draw (int): Sample in each draw; rounded down to a multiple
                of ``spsa_samples`` and must remain positive, i.e. callers must
                pass ``sample_per_draw >= spsa_samples``.
            nb_iter (int): Number of iteration (compared against the cumulative
                query count in :meth:`spsa`).
            early_stop_loss_threshold (float): The threshold for early stop.
            target (bool): Conduct target/untarget attack. Defaults to False.

        Raises:
            ValueError: If ``spsa_samples`` rounds down to zero.
        '''
        self.model = model
        self.device = device
        self.IsTargeted = target
        self.eps = eps
        self.learning_rate = learning_rate
        self.delta = delta
        spsa_samples = spsa_samples if spsa_samples else sample_per_draw
        # Round down to an even number: SPSA evaluates antithetic +/- pairs.
        self.spsa_samples = (spsa_samples // 2) * 2
        if self.spsa_samples <= 0:
            # Previously spsa_samples=1 crashed below with an opaque
            # ZeroDivisionError; fail early with a clear message instead.
            raise ValueError("spsa_samples must be at least 2 after rounding down to an even number.")
        self.sample_per_draw = (sample_per_draw // self.spsa_samples) * self.spsa_samples
        assert self.sample_per_draw > 0, "Sample per draw must not be zero."
        self.nb_iter = nb_iter
        self.norm = norm
        self.early_stop_loss_threshold = early_stop_loss_threshold
        self.clip_min = 0
        self.clip_max = 1

    def clip_eta(self, batchsize, eta, norm, eps):
        '''Clip the perturbation ``eta`` into the ``norm``-ball of radius ``eps``.

        Args:
            batchsize (int): Number of images in the batch (first dim of ``eta``).
            eta (torch.Tensor): Perturbation tensor of shape [N, C, H, W].
            norm (float): The norm to clip under; np.inf and 2 are supported.
            eps (float): Radius of the norm ball.

        Returns:
            torch.Tensor: The clipped perturbation.

        Raises:
            NotImplementedError: For any norm other than np.inf or 2.
        '''
        if norm == np.inf:
            eta = torch.clamp(eta, -eps, eps)
        elif norm == 2:
            # Bug fix: use the `norm` parameter rather than self.norm so the
            # method honors its own argument (call sites pass self.norm, so
            # behavior is unchanged for existing callers).
            normVal = torch.norm(eta.view(batchsize, -1), norm, 1)
            # Only scale down perturbations whose norm exceeds eps.
            mask = normVal <= eps
            scaling = eps / normVal
            scaling[mask] = 1
            eta = eta * scaling.view(batchsize, 1, 1, 1)
        else:
            raise NotImplementedError
        return eta

    def _get_batch_sizes(self, n, max_batch_size):
        '''Split ``n`` samples into chunks of at most ``max_batch_size``.'''
        batches = [max_batch_size for _ in range(n // max_batch_size)]
        if n % max_batch_size > 0:
            batches.append(n % max_batch_size)
        return batches

    def _compute_spsa_gradient(self, loss_fn, x, delta, nb_sample, max_batch_size):
        '''Estimate the gradient of ``loss_fn`` at ``x`` via SPSA.

        Draws random +/-1 Rademacher directions ``v``, evaluates the loss at
        ``+delta*v`` and ``-delta*v``, and averages the finite-difference
        estimates over ``nb_sample`` draws, processed in chunks of
        ``max_batch_size``.
        '''
        grad = torch.zeros_like(x)
        x = x.unsqueeze(0)
        x = x.expand(max_batch_size, *x.shape[1:]).contiguous()
        # One random sign per spatial location, shared across channels
        # (the [:, :1, ...] slice keeps only the first channel, then the
        # signs are broadcast back over channels via expand_as).
        v = torch.empty_like(x[:, :1, ...])
        for batch_size in self._get_batch_sizes(nb_sample, max_batch_size):
            x_ = x[:batch_size]
            vb = v[:batch_size]
            # Rademacher noise: bernoulli {0,1} mapped to {-1,+1}.
            vb = vb.bernoulli_().mul_(2.0).sub_(1.0)
            v_ = vb.expand_as(x_).contiguous()
            x_shape = x_.shape
            x_ = x_.view(-1, *x.shape[2:])
            v_ = v_.view(-1, *v.shape[2:])
            # Two-sided finite difference of the loss along each direction.
            df = loss_fn(delta * v_) - loss_fn(- delta * v_)
            df = df.view(-1, *[1 for _ in v_.shape[1:]])
            grad_ = df / (2. * delta * v_)
            grad_ = grad_.view(x_shape)
            grad_ = grad_.sum(dim=0, keepdim=False)
            grad += grad_
        grad /= nb_sample
        return grad

    def _is_adversarial(self, x, y, y_target):
        '''Return whether ``x`` is adversarial for the model.

        Targeted attack: success means the prediction equals ``y_target``.
        Untargeted attack: success means the prediction differs from ``y``.
        '''
        output = torch.argmax(self.model(x), dim=1)
        if self.IsTargeted:
            return output == y_target
        else:
            return output != y

    def _margin_logit_loss(self, logits, labels, target_label):
        '''Margin loss between the (target) class logit and the best other logit.

        Targeted: returns max_{i != target} logit_i - logit_target (minimizing
        drives the target logit above all others). Untargeted: returns the
        negation with respect to the true label, so minimizing pushes the true
        class below some other class.
        '''
        if self.IsTargeted:
            correct_logits = logits.gather(1, target_label[:, None]).squeeze(1)
            logit_indices = torch.arange(
                logits.size()[1],
                dtype=target_label.dtype,
                device=target_label.device)[None, :].expand(target_label.size()[0], -1)
            # Mask out the target class with -inf before taking the max.
            incorrect_logits = torch.where(logit_indices == target_label[:, None],
                                           torch.full_like(logits, float("-inf")),
                                           logits)
            max_incorrect_logits, _ = torch.max(incorrect_logits, 1)
            return max_incorrect_logits - correct_logits
        else:
            correct_logits = logits.gather(1, labels[:, None]).squeeze(1)
            logit_indices = torch.arange(
                logits.size()[1],
                dtype=labels.dtype,
                device=labels.device)[None, :].expand(labels.size()[0], -1)
            incorrect_logits = torch.where(logit_indices == labels[:, None],
                                           torch.full_like(logits, float("-inf")),
                                           logits)
            max_incorrect_logits, _ = torch.max(incorrect_logits, 1)
            return -(max_incorrect_logits - correct_logits)

    def spsa(self, x, y, y_target):
        '''The main function of SPSA attack on a single image (batch of 1).

        Args:
            x (torch.Tensor): Input image, shape [1, C, H, W], range [0, 1].
            y (torch.Tensor): True label, shape [1].
            y_target (torch.Tensor or None): Target label for targeted attack.

        Returns:
            torch.Tensor: The adversarial image (or the input if already
            adversarial). Also records 'queries' and 'success' in self.detail.
        '''
        device = self.device
        eps = self.eps
        batchsize = x.shape[0]
        learning_rate = self.learning_rate
        delta = self.delta
        spsa_samples = self.spsa_samples
        nb_iter = self.nb_iter
        v_x = x.to(device)
        v_y = y.to(device)
        # Skip the attack entirely if the clean input is already adversarial.
        if self._is_adversarial(v_x, v_y, y_target):
            self.detail['queries'] = 0
            self.detail['success'] = True
            return v_x
        # Random init inside the inf-ball. Gradients are assigned manually
        # (perturbation.grad = spsa_grad), so no autograd graph is needed.
        perturbation = (torch.rand_like(v_x) * 2 - 1) * eps
        optimizer = optim.Adam([perturbation], lr=learning_rate)
        self.detail['success'] = False
        queries = 0
        # NOTE(review): queries counts model evaluations (sample_per_draw per
        # step) but is compared against nb_iter — presumably nb_iter is a
        # query budget rather than a step count; confirm against callers.
        while queries + self.sample_per_draw <= nb_iter:
            queries += self.sample_per_draw

            def loss_fn(pert):
                input1 = v_x + pert
                input1 = torch.clamp(input1, self.clip_min, self.clip_max)
                logits = self.model(input1)
                if self.IsTargeted:
                    return self._margin_logit_loss(logits, v_y.expand(len(pert)),
                                                   y_target.expand(len(pert)))
                return self._margin_logit_loss(logits, v_y.expand(len(pert)), None)

            spsa_grad = self._compute_spsa_gradient(loss_fn, v_x, delta=delta,
                                                    nb_sample=spsa_samples,
                                                    max_batch_size=self.sample_per_draw)
            perturbation.grad = spsa_grad
            optimizer.step()
            # Project back into the eps-ball and the valid pixel range; the
            # in-place add_ keeps the same tensor object the optimizer tracks.
            clip_perturbation = self.clip_eta(batchsize, perturbation, self.norm, eps)
            adv_image = torch.clamp(v_x + clip_perturbation, self.clip_min, self.clip_max)
            perturbation.add_((adv_image - v_x) - perturbation)
            loss = loss_fn(perturbation).item()
            if (self.early_stop_loss_threshold is not None
                    and loss < self.early_stop_loss_threshold):
                break
            if self._is_adversarial(adv_image, v_y, y_target):
                self.detail['success'] = True
                break
        self.detail['queries'] = queries
        return adv_image

    def __call__(self, images=None, labels=None, target_labels=None):
        '''This function perform attack on target images with corresponding
        labels and target labels for target attack.

        Args:
            images (torch.Tensor): The images to be attacked. The images should be
                torch.Tensor with shape [N, C, H, W] and range [0, 1].
            labels (torch.Tensor): The corresponding labels of the images. The labels
                should be torch.Tensor with shape [N, ].
            target_labels (torch.Tensor): The target labels for target attack. The labels
                should be torch.Tensor with shape [N, ].

        Returns:
            torch.Tensor: Adversarial images with value range [0, 1].
        '''
        adv_images = []
        self.detail = {}
        # SPSA attacks one image at a time; iterate and re-batch at the end.
        for i in range(len(images)):
            if target_labels is None:
                target_label = None
            else:
                target_label = target_labels[i].unsqueeze(0)
            adv_x = self.spsa(images[i].unsqueeze(0), labels[i].unsqueeze(0), target_label)
            adv_images.append(adv_x)
        adv_images = torch.cat(adv_images, 0)
        return adv_images