Source code for ares.attack.detection.trainer

import os
import time

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from tqdm import tqdm

from .custom.lr_scheduler import build_lr_scheduler
from .utils import build_optimizer
from .utils import get_word_size
from .utils import is_distributed
from .utils import mkdirs_if_not_exists
from .utils import save_images
from .utils import all_reduce

class Trainer:
    """Base trainer class.

    Args:
        cfg (mmengine.config.ConfigDict): Attack config dict.
        model (torch.nn.Module): Model to be trained or evaluated.
        train_dataloader (torch.utils.data.DataLoader): Dataloader for training.
        test_dataloader (torch.utils.data.DataLoader): Dataloader for testing.
        evaluator (class): Evaluator to evaluate detection performance.
        logger (logging.Logger): Logger to record information.
    """
    def __init__(self, cfg, model, train_dataloader, test_dataloader, evaluator, logger):
        self.cfg = cfg
        self.model = model
        self.train_dataloader = train_dataloader
        self.test_dataloader = test_dataloader
        self.evaluator = evaluator
        self.logger = logger
        self.epochs = self.cfg.get('epochs', 0)
        self.before_start()
    def train(self):
        """Train model."""
        self.before_train()
        for epoch in range(1, self.epochs + 1):
            self.runtime['epoch'] = epoch
            self.before_epoch()
            self.run_epoch()
            self.after_epoch()
        self.after_train()
    @torch.no_grad()
    def eval(self, eval_on_clean=False):
        """Evaluate detection performance."""
        self.before_eval()
        if eval_on_clean:
            self.eval_clean()

        save_adv_images = self.cfg.adv_image.save
        if save_adv_images:
            adv_image_save_dir = os.path.join(self.cfg.log_dir, self.cfg.adv_image.save_folder)
            mkdirs_if_not_exists(adv_image_save_dir)
        if self.is_distributed:
            dist.barrier()

        self.logger.info('Evaluating detection performance on attacked data...')
        for i, batch_data in tqdm(enumerate(self.test_dataloader), total=len(self.test_dataloader)):
            with torch.cuda.amp.autocast(enabled=self.cfg.amp):
                returned_dict = self.model(batch_data)
            preds = returned_dict['preds']
            if save_adv_images:
                save_images(returned_dict['adv_images'], preds, adv_image_save_dir, self.cfg.adv_image.with_bboxes)
            self.evaluator.process(data_samples=preds)
        metrics = self.evaluator.evaluate(len(self.test_dataloader.dataset))
        return metrics
    def run_epoch(self):
        """Train for one epoch."""
        epoch_loss = torch.tensor(0.0, device=self.rank, requires_grad=False)
        t1 = time.time()
        for i, batch_data in enumerate(self.train_dataloader):
            with torch.cuda.amp.autocast(enabled=self.cfg.amp):
                losses = self.model(batch_data)
            loss = sum(losses.values())
            if self.cfg.loss_fn.tv_loss.enable:
                loss_tv = losses['loss_tv'].detach()
                all_reduce(loss_tv, 'avg')
            batch_loss = loss.detach()
            all_reduce(batch_loss, 'avg')
            epoch_loss += batch_loss.item()
            self.runtime['total_loss']['value'] += batch_loss.item()
            self.runtime['total_loss']['length'] += 1

            # Some batch images without gt bounding boxes may lead to no grad being computed,
            # in which case backward() raises and the update for this batch is skipped.
            try:
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                # Patch values should stay in range [0, 1]; clamp in place so the update sticks.
                if self.is_distributed:
                    self.model.module.patch.data.clamp_(min=0, max=1)
                else:
                    self.model.patch.data.clamp_(min=0, max=1)
            except RuntimeError:
                pass

            t2 = time.time()
            batch_time = t2 - t1
            t1 = t2
            if i % self.cfg.log_period == 0:
                epoch = self.runtime['epoch']
                length = len(self.train_dataloader)
                avg_loss = self.runtime['total_loss']['value'] / self.runtime['total_loss']['length']
                info = f'Epoch:{epoch:3d}/{self.epochs} Iter:{i:4d}/{length} loss:{batch_loss:.2f} ({avg_loss:.2f}) '
                if self.cfg.loss_fn.tv_loss.enable:
                    info += f'loss tv:{loss_tv.item():.2f} '
                info += f'time:{batch_time:.1f}s'
                self.logger.info(info)
        self.runtime['epoch_loss'] = epoch_loss
    @torch.no_grad()
    def eval_clean(self):
        """Evaluate detection performance on clean data."""
        if self.cfg.clean_image.save:
            clean_image_save_dir = os.path.join(self.cfg.log_dir, self.cfg.clean_image.save_folder)
            mkdirs_if_not_exists(clean_image_save_dir)
        self.logger.info('Evaluating detection performance on clean data...')
        model = self.model.module if self.is_distributed else self.model
        for i, batch_data in tqdm(enumerate(self.test_dataloader), total=len(self.test_dataloader)):
            with torch.cuda.amp.autocast(enabled=self.cfg.amp):
                preds, images = model.bbox_predict(batch_data, return_images=True)
            self.evaluator.process(data_samples=preds)
            if self.cfg.clean_image.save:
                save_images(images, preds, clean_image_save_dir, self.cfg.clean_image.with_bboxes)
        self.evaluator.evaluate(len(self.test_dataloader.dataset))
    def before_eval(self):
        """Do something before evaluating."""
        if self.is_distributed:
            self.model.module.eval()
        else:
            self.model.eval()
        self.test_dataloader.sampler.shuffle = False
    def before_epoch(self):
        """Do something before each training epoch."""
        epoch = self.runtime['epoch']
        if self.is_distributed:
            self.model.module.train()
            self.train_dataloader.batch_sampler.sampler.set_epoch(epoch)
        else:
            self.model.train()
    def after_epoch(self):
        """Do something after each training epoch."""
        epoch = self.runtime['epoch']
        all_reduce(self.runtime['epoch_loss'], reduction='sum')
        epoch_loss = self.runtime['epoch_loss'] / len(self.train_dataloader)
        self.lr_scheduler.step(loss=epoch_loss.item(), epoch=epoch)
        if epoch % self.cfg.patch.save_period == 0:
            model = self.model.module if self.is_distributed else self.model
            model.save_patch(self.runtime['epoch'], is_best=False)
        if epoch % self.cfg.eval_period == 0 and epoch != self.epochs:
            metrics = self.eval(eval_on_clean=False)
            is_best = metrics['coco/bbox_mAP'] < self.runtime['lowest_bbox_mAP']
            if is_best:
                self.runtime['lowest_bbox_mAP'] = metrics['coco/bbox_mAP']
                self.logger.info('Lowest mAP updated!')
            model = self.model.module if self.is_distributed else self.model
            model.save_patch(self.runtime['epoch'], is_best=is_best)
    def before_train(self):
        """Automatically scale learning rate, build optimizer and lr_scheduler before training."""
        self.train_dataloader.sampler.shuffle = True
        if self.is_distributed:
            params = [{'params': self.model.module.patch}]
        else:
            params = [{'params': self.model.patch}]
        self.scale_lr()
        self.optimizer = build_optimizer(params, **self.cfg.optimizer)
        self.lr_scheduler = build_lr_scheduler(self.optimizer, **self.cfg.lr_scheduler)
    def after_train(self):
        """Do something after finishing training."""
        metrics = self.eval(eval_on_clean=False)
        if self.rank == 0:
            patch_save_dir = os.path.join(self.cfg.log_dir, self.cfg.patch.save_folder)
            mkdirs_if_not_exists(patch_save_dir)
        is_best = metrics['coco/bbox_mAP'] < self.runtime['lowest_bbox_mAP']
        if is_best:
            self.runtime['lowest_bbox_mAP'] = metrics['coco/bbox_mAP']
            self.logger.info('Lowest mAP updated!')
        model = self.model.module if self.is_distributed else self.model
        model.save_patch(self.runtime['epoch'], is_best=is_best)
    def before_start(self):
        """Initialization before starting training or evaluating."""
        if self.cfg.attack_mode == 'patch':
            if self.cfg.patch.get('resume_path'):
                self.model.load_patch(self.cfg.patch.resume_path)

        # to store some variables during the training period
        self.runtime = {}
        self.runtime['total_loss'] = {'value': 0, 'length': 0}
        self.runtime['epoch'] = 1
        self.runtime['lowest_bbox_mAP'] = 1.0

        self.is_distributed = is_distributed()
        self.rank = 0 if not self.is_distributed else dist.get_rank()
        self.device = torch.device(self.rank)
        self.world_size = get_word_size() if self.is_distributed else 1
        self.model = self.model.to(self.device)
        self.model.detector.to(self.device)
        if self.is_distributed:
            self.model = DistributedDataParallel(self.model,
                                                 device_ids=[self.device],
                                                 output_device=self.device,
                                                 find_unused_parameters=True)
            self.model.module.freeze_layers(self.model.module.detector)
        else:
            self.model.freeze_layers(self.model.detector)
    def scale_lr(self):
        """Automatically scale the learning rate based on the base batch size and the real batch size."""
        if self.cfg.get('auto_lr_scaler'):
            base_batch_size = self.cfg.auto_lr_scaler.base_batch_size
            real_batch_size = self.world_size * self.cfg.batch_size
            ratio = float(real_batch_size) / float(base_batch_size)
            self.cfg.optimizer.kwargs.lr *= ratio
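
# ----------------------------------------------------------------------------
# Minimal usage sketch (not part of the library): the commented-out lines below
# only illustrate how a Trainer is typically wired together, based on the
# constructor signature above. The config path, the build_* helpers and the
# logger name are hypothetical placeholders; in practice the corresponding
# objects are created by the attack entry script from an mmengine config.
# As an aside, with the scale_lr() rule above, e.g. 4 GPUs * batch_size 4
# against base_batch_size 8 would multiply the configured lr by 2.0.
# ----------------------------------------------------------------------------
# from mmengine.config import Config
#
# cfg = Config.fromfile('configs/patch_attack.py')      # hypothetical config file
# model = build_attacked_model(cfg)                     # hypothetical helper
# train_loader, test_loader = build_dataloaders(cfg)    # hypothetical helper
# evaluator = build_evaluator(cfg)                      # hypothetical helper
# logger = logging.getLogger('ares.detection')
#
# trainer = Trainer(cfg, model, train_loader, test_loader, evaluator, logger)
# trainer.train()                                       # optimize the adversarial patch
# metrics = trainer.eval(eval_on_clean=True)            # report clean and attacked mAP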