import os
import torch
import numpy as np
import torch.nn as nn
from ares.utils.registry import Registry
from mmengine.structures import InstanceData
from .patch.patch_applier import PatchApplier
from .utils import EnableLossCal
from .utils import normalize, denormalize, main_only
from .utils import tv_loss, mkdirs_if_not_exists, save_patches_to_images
class UniversalAttacker(nn.Module):
    '''Class supports both global perturbation attack and patch attack.

    Args:
        cfg (mmengine.config.ConfigDict): Configs for adversarial attack.
        detector (torch.nn.Module): Detector to be attacked.
        logger (logging.Logger): Logger to record logs.
        device (torch.device): torch.device. Default: torch.device(0).
    '''

    def __init__(self, cfg, detector, logger, device=torch.device(0)):
        super().__init__()
        self.cfg = cfg
        self.logger = logger
        self.detector = detector
        # Load pretrained detector weights right away (path comes from cfg).
        self.load_detector_weight()
        self.data_preprocessor = detector.data_preprocessor
        self.device = device
        # Inferred lazily from the first batch: 1.0 or 255.0 pixel scale.
        self.detector_image_max_val = None
        # Dispatch initialization on the configured attack mode.
        mode = self.cfg.attack_mode
        if mode == 'patch':
            self.init_for_patch_attack()
        elif mode == 'global':
            self.init_for_global_attack()
        else:
            raise ValueError('Supported attack modes are patch or global, but got %s instead.' % self.cfg.attack_mode)
[docs] def forward(self, batch_data, return_adv_images_only=False):
'''
Args:
batch_data (dict): Input batch data. Example: {'inputs': torch.Tensor with shape [N,C,H,W],
'data_samples':list of mmdet.structures.det_data_sample.DetDataSample with length N.
}
return_adv_images_only (bool): Whether to return adv images only without bboxes prediction.
Default: False
Returns:
dict. It may contain keys losses, adv_images.
'''
batch_data = self.data_preprocessor(batch_data)
if self.cfg.attack_mode == 'patch':
return self.patch_forward(batch_data, return_adv_images_only)
elif self.cfg.attack_mode == 'global':
return self.global_forward(batch_data, return_adv_images_only)
else:
raise ValueError('Supported attack modes are patch or global, but got %s instead.' % self.cfg.attack_mode)
[docs] def global_forward(self, batch_data, return_adv_images_only=False):
'''For global perturbation attack.'''
# denormalize images to range [0, 1]
images = batch_data['inputs']
images = denormalize(images, self.data_preprocessor.mean, self.data_preprocessor.std)
# set image value range. We suppose the range is 0-1 or 0-255.
if self.detector_image_max_val is None:
max_val, min_val = images[0][0].max(), images[0][0].min()
if max_val > 1.5 and min_val >= 0:
self.detector_image_max_val = 255.0
elif max_val <= 1 and min_val >= 0:
self.detector_image_max_val = 1.0
else:
raise ValueError(f"Expected image pixel value range before normalization is [0, 1] or [0, 255], but got min value {min_val}, max value {max_val}!")
images = images / self.detector_image_max_val
batch_data['inputs'] = images
if self.cfg.object_vanish_only:
self.set_gt_ann_empty(batch_data['data_samples'])
with torch.enable_grad():
with EnableLossCal(self.detector):
adv_images = self.attack_method.attack_detection_forward(batch_data, self.cfg.loss_fn.get('excluded_losses', []),
self.detector_image_max_val,
self.cfg.object_vanish_only)
if return_adv_images_only:
return {'adv_images': adv_images}
# normalize adv images for detector input
normed_adv_images = normalize(adv_images * self.detector_image_max_val, self.data_preprocessor.mean,
self.data_preprocessor.std)
preds = self.bbox_predict({'inputs':normed_adv_images, 'data_samples':batch_data['data_samples']},
need_preprocess=False)
returned_dict = {'preds': preds, 'adv_images': adv_images}
return returned_dict
    def patch_forward(self, batch_data, return_adv_images_only=False):
        '''For patch attack: paste per-class adversarial patches onto gt boxes.

        In training mode returns a dict of losses driving patch optimization;
        in eval mode returns predictions on the patched images.
        '''
        # denormalize images to range [0, 1]
        images = batch_data['inputs']
        images = denormalize(images, self.data_preprocessor.mean, self.data_preprocessor.std)
        # set image value range. We suppose the range is 0-1 or 0-255.
        # NOTE(review): the scale detected from the first batch is cached on the
        # instance and reused for every later batch.
        if self.detector_image_max_val is None:
            max_val, min_val = images[0][0].max(), images[0][0].min()
            if max_val > 1.5 and min_val >= 0:
                self.detector_image_max_val = 255.0
            elif max_val <= 1 and min_val >= 0:
                self.detector_image_max_val = 1.0
            else:
                raise ValueError(
                    f"Expected image pixel value range before normalization is [0, 1] or [0, 255], but got min value {min_val}, max value {max_val}!")
        images = images / self.detector_image_max_val
        # Collect per-image gt boxes/labels; when self.attacked_labels is set,
        # only instances of those classes receive a patch.
        bboxes_list, labels_list = [], []
        for i, data in enumerate(batch_data['data_samples']):
            bboxes = data.gt_instances.bboxes.clone()
            labels = data.gt_instances.labels
            if self.attacked_labels is None:
                bboxes_list.append(bboxes)
                labels_list.append(labels)
            else:
                # keep rows whose label matches any attacked label
                mask = (labels[:, None] == self.attacked_labels).any(dim=1)
                bboxes_list.append(bboxes[mask])
                labels_list.append(labels[mask])
        # Paste class-specific patches onto the [0, 1]-ranged images.
        adv_images = self.patch_applier(images, self.patch, bboxes_list, labels_list)
        if return_adv_images_only:
            return {'adv_images': adv_images}
        normed_adv_images = normalize(adv_images * self.detector_image_max_val, self.data_preprocessor.mean, self.data_preprocessor.std)
        if self.training:
            if self.cfg.object_vanish_only:
                self.set_gt_ann_empty(batch_data['data_samples'])
            detector_losses = self.detector.loss(normed_adv_images, batch_data['data_samples'])
            attacked_detector_loss = self.filter_loss(detector_losses)
            losses = {'loss_detector': attacked_detector_loss}
            if self.cfg.loss_fn.tv_loss.enable:
                # Total-variation smoothness loss only over patches used in this batch.
                selected_patch_indices = torch.cat((labels_list)).unique()
                selected_patches = self.patch[selected_patch_indices]
                loss_tv = tv_loss(selected_patches)
                # clamp from below at tv_thresh so the tv term stops shrinking once small enough
                loss_tv = torch.max(self.cfg.loss_fn.tv_loss.tv_scale * loss_tv, torch.tensor(self.cfg.loss_fn.tv_loss.tv_thresh).to(loss_tv.device))
                losses.update({'loss_tv': loss_tv})
            return losses
        else:
            preds = self.bbox_predict({'inputs':normed_adv_images, 'data_samples':batch_data['data_samples']},
                                      need_preprocess=False)
            returned_dict = {'preds': preds, 'adv_images': adv_images}
            return returned_dict
[docs] def init_for_patch_attack(self):
'''Initialize adversarial patch, patch applier and attacked labels for patch attack.'''
self.patch = self.init_patch(init_mode=self.cfg.patch.init_mode)
self.patch_applier = PatchApplier(self.cfg.patch)
self.attacked_labels = self.cfg.get('attacked_labels', None)
if self.attacked_labels:
self.attacked_labels = torch.Tensor(self.attacked_labels).to(self.device)
[docs] def init_for_global_attack(self):
'''Initialize attack method for global attack.'''
norm_type = self.cfg.attack_method.kwargs.norm
if norm_type == 'l2':
self.cfg.attack_method.kwargs.norm = 2
elif norm_type == 'inf':
self.cfg.attack_method.kwargs.norm = np.inf
else:
raise ValueError('Only l2 and inf norm are supported, bu got %s instead' % norm_type)
self.attack_method = Registry.get_attack(self.cfg.attack_method.type)(self.detector, device=self.device,
**self.cfg.attack_method.kwargs)
[docs] def set_gt_ann_empty(self, data_samples):
'''Set gt bboxes and gt labels zero tensors for object_vanish_only goal.'''
bboxes = torch.zeros((1, 4), dtype=torch.float32, device=self.device)
labels = torch.zeros((1,), dtype=torch.long, device=self.device)
empty_gt = InstanceData(bboxes=bboxes, labels=labels, metainfo={})
for data_sample in data_samples:
data_sample.gt_instances = empty_gt
[docs] def filter_loss(self, losses):
'''Collect losses not in self.cfg.loss_fn.excluded_losses.'''
loss_list = []
for key in losses.keys():
if isinstance(losses[key], list):
losses[key] = torch.stack(losses[key]).mean()
kept = True
for excluded_loss in self.cfg.loss_fn.excluded_losses:
if excluded_loss in key:
kept = False
continue
if kept and 'loss' in key:
loss_list.append(losses[key].mean().unsqueeze(0))
if self.cfg.object_vanish_only:
loss = torch.stack(loss_list).mean()
else:
loss = -torch.stack(loss_list).mean()
return loss
[docs] def freeze_layers(self, modules):
'''Freeze given modules via setting their requires_grad attribute False.'''
for _, parameter in modules.named_parameters():
parameter.requires_grad = False
[docs] def init_patch(self, init_mode='gray'):
'''Initialize adversarial patch with given init_mode.'''
assert init_mode in ['gray', 'white', 'black', 'random'], \
'Expected patch initilization modes are gray, while, ' \
'black or ramdom, bug got %s instead' % init_mode
height = self.cfg.patch.size
width = self.cfg.patch.size
try:
num_classes = self.detector.bbox_head.num_classes
except:
num_classes = self.detector.roi_head.bbox_head.num_classes
self.logger.info('Adversarial patches initialzed by %s mode' % init_mode)
if init_mode.lower() == 'random':
patch = torch.rand((num_classes, 3, height, width))
elif init_mode.lower() == 'gray':
patch = torch.full((num_classes, 3, height, width), 0.5)
elif init_mode.lower() == 'white':
patch = torch.full((num_classes, 3, height, width), 1.0)
elif init_mode.lower() == 'black':
patch = torch.full((num_classes, 3, height, width), 0)
patch = nn.Parameter(patch, requires_grad=True)
return patch
[docs] def load_detector_weight(self):
'''Load detector weight from file.'''
self.logger.info('Load detector weight from path: %s' % self.cfg.detector.weight_file)
state_dict = torch.load(self.cfg.detector.weight_file, map_location='cpu')['state_dict']
result = self.detector.load_state_dict(state_dict, strict=False)
self.logger.info(result)
[docs] def train(self, mode: bool = True):
'''Set self to training mode.'''
self.training = mode
for module in self.children():
module.train(mode)
# detector should be set in eval model always!
self.detector.eval()
self.detector.training = True # to make detector.loss() work as in training mode
return self
[docs] def eval(self):
'''Set self to eval mode.'''
self.train(False)
self.detector.training = False
return self
[docs] def load_patch(self, patch_path):
'''Initialize patch with given patch_path'''
self.logger.info('Load adversarial patch from path: %s' % patch_path)
patches = torch.load(patch_path, map_location=self.device)
self.patch = torch.nn.Parameter(patches)
    @main_only
    def save_patch(self, epoch=None, is_best=False):
        '''Save adversarial patch to file.

        Args:
            epoch (int | None): Epoch tag embedded in the output file names.
            is_best (bool): Additionally save a copy under fixed 'best-*' names.
        '''
        patch_save_dir = os.path.join(self.cfg.log_dir, self.cfg.patch.save_folder)
        mkdirs_if_not_exists(patch_save_dir)
        # detach: saved patches should not carry the autograd graph
        patches = self.patch.detach()
        patch_file_path = os.path.join(patch_save_dir, 'patches@epoch-' + str(epoch) + '.pth')
        patch_images_path = os.path.join(patch_save_dir, 'patch-images@epoch-' + str(epoch))
        # Patch <-> class mapping: when only some classes are attacked, keep their
        # names and labels; otherwise patches cover all classes.
        if self.cfg.get('attacked_classes'):
            patch_classes = self.cfg.attacked_classes
            patch_labels = self.cfg.attacked_labels
        else:
            patch_classes = self.cfg.all_classes
            patch_labels = None
        torch.save(patches.cpu(), patch_file_path)
        save_patches_to_images(patches, patch_images_path, patch_classes, patch_labels)
        if is_best:
            # NOTE(review): '%d' assumes epoch is an int here — a call with
            # is_best=True and epoch=None would raise; confirm callers always pass epoch.
            self.logger.info('save best patches in epoch %d' % epoch)
            patch_file_path = os.path.join(patch_save_dir, 'best-patches.pth')
            patch_images_path = os.path.join(patch_save_dir, 'best-patch-images')
            torch.save(patches.cpu(), patch_file_path)
            save_patches_to_images(patches, patch_images_path, patch_classes, patch_labels)
[docs] def bbox_predict(self, batch_data, need_preprocess=True, return_images=False):
"""
Args:
batch_data (dict): A dict contains inputs and data_samples attributes. See self.forward() for details.
need_preprocess (bool): Whether to preprocess batch_data.
return_images (bool): Whether to return input images.
Returns:
list or tuple : If list, return preds which is list of mmdet.structure.DetDataSample containing pred_instances attribute.
If tuple, return (preds, images) where images are batch_data['inputs'], torch.Tensor with shape [N,C,H,W].
"""
if need_preprocess:
batch_data = self.data_preprocessor(batch_data)
preds = self.detector.predict(batch_data['inputs'], batch_data['data_samples'])
if return_images:
images = batch_data['inputs']
images = denormalize(images, self.data_preprocessor.mean, self.data_preprocessor.std)
# set image value range. We suppose the range is 0-1 or 0-255.
if self.detector_image_max_val is None:
max_val, min_val = images[0][0].max(), images[0][0].min()
if max_val > 1.5 and min_val >= 0:
self.detector_image_max_val = 255.0
elif max_val <= 1 and min_val >= 0:
self.detector_image_max_val = 1.0
else:
raise ValueError(
f"Expected image pixel value range before normalization is [0, 1] or [0, 255], but got min value {min_val}, max value {max_val}!")
images = images / self.detector_image_max_val
return preds, images
return preds