Robust Training for Image Classification

Overview

This project contains code for adversarial training of image classification models. Heavy data augmentations can be combined with adversarial training. The well-known timm library is used as the default classification library.
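For readers new to the idea, the sketch below shows one generic adversarial training step. It assumes an L-infinity PGD attack and inputs scaled to [0, 1]; the attack, the function names (pgd_attack, adversarial_training_step) and the default values for eps, step_size and steps are illustrative assumptions, and the training script in this project may work differently.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def pgd_attack(model, images, labels, eps=4 / 255, step_size=1 / 255, steps=3):
    """Craft L-infinity PGD adversarial examples (assumes inputs in [0, 1])."""
    was_training = model.training
    model.eval()  # keep BatchNorm/Dropout behaviour fixed while crafting the attack
    # Random start inside the eps-ball around the clean images.
    delta = torch.empty_like(images).uniform_(-eps, eps)
    adv = (images + delta).clamp(0.0, 1.0).detach()
    for _ in range(steps):
        adv.requires_grad_(True)
        loss = F.cross_entropy(model(adv), labels)
        # Gradient with respect to the input only; parameter grads are untouched.
        grad, = torch.autograd.grad(loss, adv)
        # Ascend the loss, then project back into the eps-ball and pixel range.
        adv = adv.detach() + step_size * grad.sign()
        adv = torch.min(torch.max(adv, images - eps), images + eps).clamp(0.0, 1.0)
    model.train(was_training)
    return adv.detach()


def adversarial_training_step(model, optimizer, images, labels, **attack_kwargs):
    """Run one optimizer step on adversarial examples and return the loss value."""
    adv_images = pgd_attack(model, images, labels, **attack_kwargs)
    model.train()
    optimizer.zero_grad()
    loss = F.cross_entropy(model(adv_images), labels)
    loss.backward()
    optimizer.step()
    return loss.item()


if __name__ == "__main__":
    # Toy model and random data, only to show how the two functions fit together.
    toy_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))
    toy_optimizer = torch.optim.SGD(toy_model.parameters(), lr=0.1)
    toy_images = torch.rand(4, 3, 8, 8)
    toy_labels = torch.randint(0, 10, (4,))
    print(adversarial_training_step(toy_model, toy_optimizer, toy_images, toy_labels))
```

In a real run, any augmentation would be applied to the batch before the attack, and the model would be a full classifier rather than the toy linear layer used here.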

Preparation

  • Dataset

    We use the ImageNet validation set as the default dataset for evaluating the adversarial robustness of classification models. Please download the ImageNet dataset first. To use your own dataset, define a torch.utils.data.Dataset class for it together with the corresponding transform.

  • Classification Models

    To build an image classification model, you can either create a model from the timm library or define a custom network as a subclass of torch.nn.Module.
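The two preparation steps above can be sketched as follows. This is a minimal illustration rather than the project's own code: InMemoryImageDataset, random_horizontal_flip and SmallConvNet are hypothetical names, and the random tensors stand in for real images.

```python
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader


class InMemoryImageDataset(Dataset):
    """Hypothetical dataset wrapping image tensors and integer labels in memory."""

    def __init__(self, images, labels, transform=None):
        assert len(images) == len(labels)
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image = self.images[index]
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[index]


def random_horizontal_flip(image, p=0.5):
    """Hypothetical transform: flip a CHW tensor left-right with probability p."""
    if torch.rand(1).item() < p:
        return image.flip(-1)
    return image


class SmallConvNet(nn.Module):
    """Hypothetical custom network; any torch.nn.Module classifier would do."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),
        )
        self.classifier = nn.Linear(16, num_classes)

    def forward(self, x):
        return self.classifier(self.features(x).flatten(1))


if __name__ == "__main__":
    dataset = InMemoryImageDataset(
        torch.rand(8, 3, 32, 32),
        torch.randint(0, 10, (8,)),
        transform=random_horizontal_flip,
    )
    loader = DataLoader(dataset, batch_size=4)
    model = SmallConvNet(num_classes=10)
    for images, labels in loader:
        logits = model(images)  # one row of class scores per image
        print(logits.shape, labels.shape)
```

With timm installed, the custom network could instead be replaced by a library model, for example timm.create_model("resnet50", pretrained=False).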

Robust Training

Before starting, modify the default training config files in train_configs. They contain the augmentation parameters, dataset directory, etc.

Then run the following command to start training ResNet50:

cd robust_training
python -m torch.distributed.launch --nproc_per_node=<num-of-gpus-to-use> adversarial_training.py --configs=./train_configs/resnet50.yaml