import torch
import torch.nn as nn
from timm.models import create_model, safe_model_name
from timm.models.layers import convert_splitbn_model
def normalize_fn(tensor, mean, std):
    """Differentiable version of torchvision.transforms.functional.normalize."""
    # Here we assume the color channel is at dim=1, i.e. (B, C, H, W).
    mean = mean[None, :, None, None]
    std = std[None, :, None, None]
    return tensor.sub(mean).div(std)

def denormalize(tensor, mean, std):
    '''
    Args:
        tensor (torch.Tensor): Float tensor image of size (B, C, H, W) to be denormalized.
        mean (torch.Tensor): Float tensor of per-channel means, of size (C,).
        std (torch.Tensor): Float tensor of per-channel standard deviations, of size (C,).
    '''
    # Reshape to (1, C, 1, 1) so the statistics broadcast over batch and spatial dims.
    return tensor * std[None, :, None, None] + mean[None, :, None, None]
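
# A minimal round-trip sketch of the two helpers above (the ImageNet statistics
# below are purely illustrative):
#
#   >>> mean = torch.tensor([0.485, 0.456, 0.406])
#   >>> std = torch.tensor([0.229, 0.224, 0.225])
#   >>> x = torch.rand(4, 3, 32, 32)
#   >>> x_norm = normalize_fn(x, mean, std)
#   >>> torch.allclose(denormalize(x_norm, mean, std), x, atol=1e-6)
#   True
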
class NormalizeByChannelMeanStd(nn.Module):
    '''Normalization layer that stores per-channel mean/std as buffers and
    applies them differentiably inside the model.'''

    def __init__(self, mean, std):
        super(NormalizeByChannelMeanStd, self).__init__()
        if not isinstance(mean, torch.Tensor):
            mean = torch.tensor(mean)
        if not isinstance(std, torch.Tensor):
            std = torch.tensor(std)
        self.register_buffer("mean", mean)
        self.register_buffer("std", std)

    def forward(self, tensor):
        return normalize_fn(tensor, self.mean, self.std)
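
# A common usage sketch: prepend the layer so adversarial perturbations can be
# crafted in raw [0, 1] pixel space (the backbone name is only an assumption):
#
#   >>> normalize = NormalizeByChannelMeanStd(
#   ...     mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
#   >>> model = nn.Sequential(normalize, create_model('resnet50'))
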
class SwitchableBatchNorm2d(torch.nn.BatchNorm2d):
    '''Batch norm layer with an auxiliary BN branch that can be switched in
    for adversarial batches (AdvProp-style dual batch norm).'''

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super().__init__(num_features, eps, momentum, affine, track_running_stats)
        self.bn_mode = 'clean'
        self.bn_adv = nn.BatchNorm2d(num_features, eps, momentum, affine, track_running_stats)

    def forward(self, input: torch.Tensor):
        # The auxiliary BN is only relevant while training; evaluation always
        # uses the main (clean) statistics.
        if self.training and self.bn_mode == 'adv':
            return self.bn_adv(input)
        return super().forward(input)
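
# A hedged sketch of how the dual BN is typically driven in an AdvProp-style
# training loop (the helper name is ours; `model` stands for any converted
# network):
#
#   >>> def set_bn_mode(model, mode):
#   ...     for m in model.modules():
#   ...         if isinstance(m, SwitchableBatchNorm2d):
#   ...             m.bn_mode = mode
#   >>> set_bn_mode(model, 'adv')    # forward/backward on adversarial examples
#   >>> set_bn_mode(model, 'clean')  # forward/backward on clean examples
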
def convert_switchablebn_model(module):
    """
    Recursively traverse ``module`` and its children, replacing all instances of
    ``torch.nn.modules.batchnorm._BatchNorm`` with ``SwitchableBatchNorm2d``.

    Args:
        module (torch.nn.Module): input module

    Example::

        >>> # model is an instance of torch.nn.Module
        >>> model = convert_switchablebn_model(model)
    """
    mod = module
    if isinstance(module, torch.nn.modules.instancenorm._InstanceNorm):
        return module
    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        mod = SwitchableBatchNorm2d(
            module.num_features, module.eps, module.momentum, module.affine,
            module.track_running_stats)
        # The main (clean) branch reuses the original layer's statistics and
        # affine parameters.
        mod.running_mean = module.running_mean
        mod.running_var = module.running_var
        mod.num_batches_tracked = module.num_batches_tracked
        if module.affine:
            mod.weight.data = module.weight.data.clone().detach()
            mod.bias.data = module.bias.data.clone().detach()
        # The auxiliary (adversarial) branch starts from copies of the same
        # statistics, so the two branches can drift apart during training.
        for aux in [mod.bn_adv]:
            aux.running_mean = module.running_mean.clone()
            aux.running_var = module.running_var.clone()
            aux.num_batches_tracked = module.num_batches_tracked.clone()
            if module.affine:
                aux.weight.data = module.weight.data.clone().detach()
                aux.bias.data = module.bias.data.clone().detach()
    for name, child in module.named_children():
        mod.add_module(name, convert_switchablebn_model(child))
    del module
    return mod
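
# A quick sanity check of the conversion (the backbone name is only an
# assumption for illustration):
#
#   >>> model = create_model('resnet50')
#   >>> model = convert_switchablebn_model(model)
#   >>> any(isinstance(m, SwitchableBatchNorm2d) for m in model.modules())
#   True
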
def build_model(args, _logger, num_aug_splits):
    '''Build the model for robust training.'''
    # create the model
    _logger.info(f"Creating model: {args.model}")
    model_kwargs = dict(
        num_classes=args.num_classes,
        drop_rate=args.drop,
        drop_connect_rate=args.drop_connect,  # DEPRECATED, use drop_path
        drop_path_rate=args.drop_path,
        drop_block_rate=args.drop_block,
        global_pool=args.gp,
        bn_momentum=args.bn_momentum,
        bn_eps=args.bn_eps,
    )
    model = create_model(args.model, pretrained=False, **model_kwargs)
    if args.num_classes is None:
        assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
        args.num_classes = model.num_classes
    _logger.info(
        f'Model {safe_model_name(args.model)} created, '
        f'param count: {sum(m.numel() for m in model.parameters())}')

    # enable split bn (separate bn stats per batch-portion)
    if args.split_bn:
        assert num_aug_splits > 1 or args.resplit
        model = convert_splitbn_model(model, max(num_aug_splits, 2))

    # AdvProp conversion: give every BN layer an auxiliary branch for adversarial batches
    if args.advprop:
        model = convert_switchablebn_model(model)

    model.cuda()
    if args.channels_last:
        model = model.to(memory_format=torch.channels_last)

    # setup synchronized BatchNorm for distributed training
    if args.distributed and args.sync_bn:
        if args.amp_version == 'apex':
            # Apex SyncBN preferred unless native amp is activated
            from apex.parallel import convert_syncbn_model
            model = convert_syncbn_model(model)
        else:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        _logger.info(
            'Converted model to use Synchronized BatchNorm. WARNING: You may have issues if using '
            'zero initialized BN layers (enabled by default for ResNets) while sync-bn enabled.')
    return model
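
# build_model expects an argparse-style namespace; a minimal, hedged sketch of
# the fields it reads (names mirror the function body, values are guesses, and
# a CUDA device is required because build_model calls model.cuda()):
#
#   >>> import logging
#   >>> from types import SimpleNamespace
#   >>> args = SimpleNamespace(
#   ...     model='resnet50', num_classes=1000, drop=0.0, drop_connect=None,
#   ...     drop_path=None, drop_block=None, gp=None, bn_momentum=None,
#   ...     bn_eps=None, split_bn=False, resplit=False, advprop=True,
#   ...     channels_last=False, distributed=False, sync_bn=False,
#   ...     amp_version='')
#   >>> model = build_model(args, logging.getLogger(__name__), num_aug_splits=0)
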
def load_pretrained_21k(args, model, logger):
    '''Load a pretrained ImageNet-21K checkpoint into an ImageNet-1K model.'''
    logger.info(f"==============> Loading weight {args.pretrain} for fine-tuning......")
    checkpoint = torch.load(args.pretrain, map_location='cpu')
    state_dict = checkpoint['model']

    # delete relative_position_index since we always re-init it
    relative_position_index_keys = [k for k in state_dict.keys() if "relative_position_index" in k]
    for k in relative_position_index_keys:
        del state_dict[k]

    # delete relative_coords_table since we always re-init it
    relative_coords_table_keys = [k for k in state_dict.keys() if "relative_coords_table" in k]
    for k in relative_coords_table_keys:
        del state_dict[k]

    # delete attn_mask since we always re-init it
    attn_mask_keys = [k for k in state_dict.keys() if "attn_mask" in k]
    for k in attn_mask_keys:
        del state_dict[k]
    # bicubic interpolate relative_position_bias_table if shapes do not match
    relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
    for k in relative_position_bias_table_keys:
        relative_position_bias_table_pretrained = state_dict[k]
        relative_position_bias_table_current = model.state_dict()[k]
        L1, nH1 = relative_position_bias_table_pretrained.size()
        L2, nH2 = relative_position_bias_table_current.size()
        if nH1 != nH2:
            logger.warning(f"Error in loading {k}, passing......")
        elif L1 != L2:
            # reshape the (L1, nH) table to nH maps of size (S1, S1) and resample
            S1 = int(L1 ** 0.5)
            S2 = int(L2 ** 0.5)
            relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate(
                relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1), size=(S2, S2),
                mode='bicubic')
            state_dict[k] = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0)
    # bicubic interpolate absolute_pos_embed if shapes do not match
    absolute_pos_embed_keys = [k for k in state_dict.keys() if "absolute_pos_embed" in k]
    for k in absolute_pos_embed_keys:
        absolute_pos_embed_pretrained = state_dict[k]
        absolute_pos_embed_current = model.state_dict()[k]
        _, L1, C1 = absolute_pos_embed_pretrained.size()
        _, L2, C2 = absolute_pos_embed_current.size()
        if C1 != C2:
            logger.warning(f"Error in loading {k}, passing......")
        elif L1 != L2:
            S1 = int(L1 ** 0.5)
            S2 = int(L2 ** 0.5)
            absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape(-1, S1, S1, C1)
            absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute(0, 3, 1, 2)
            absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate(
                absolute_pos_embed_pretrained, size=(S2, S2), mode='bicubic')
            absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1)
            absolute_pos_embed_pretrained_resized = absolute_pos_embed_pretrained_resized.flatten(1, 2)
            state_dict[k] = absolute_pos_embed_pretrained_resized
    # check the classifier; if it does not match, remap or re-init the head
    head_bias_pretrained = state_dict['head.bias']
    Nc1 = head_bias_pretrained.shape[0]
    Nc2 = model.head.bias.shape[0]
    if Nc1 != Nc2:
        if Nc1 == 21841 and Nc2 == 1000:
            logger.info("loading ImageNet-22K weight to ImageNet-1K ......")
            map22kto1k_path = 'data/map22kto1k.txt'
            with open(map22kto1k_path) as f:
                map22kto1k = f.readlines()
            map22kto1k = [int(id22k.strip()) for id22k in map22kto1k]
            state_dict['head.weight'] = state_dict['head.weight'][map22kto1k, :]
            state_dict['head.bias'] = state_dict['head.bias'][map22kto1k]
        else:
            torch.nn.init.constant_(model.head.bias, 0.)
            torch.nn.init.constant_(model.head.weight, 0.)
            del state_dict['head.weight']
            del state_dict['head.bias']
            logger.warning("Error in loading classifier head, re-init classifier head to 0")

    msg = model.load_state_dict(state_dict, strict=False)
    logger.warning(msg)
    logger.info(f"=> loaded successfully '{args.pretrain}'")

    del checkpoint
    torch.cuda.empty_cache()
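
# The bias-table resizing above, in isolation, as a hedged standalone sketch:
# an (L1, nH) table is treated as nH maps of size (S1, S1) and bicubically
# resampled to (S2, S2), giving an (L2, nH) table (shapes are illustrative):
#
#   >>> table = torch.randn(13 * 13, 4)   # L1 = 169, nH = 4
#   >>> resized = torch.nn.functional.interpolate(
#   ...     table.permute(1, 0).view(1, 4, 13, 13), size=(15, 15), mode='bicubic')
#   >>> resized.view(4, 15 * 15).permute(1, 0).shape
#   torch.Size([225, 4])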