# Source code for ares.model.cifar10_cls

import os
import gdown
import torch
from collections import OrderedDict
from ares.model.wideresnet import create_wres28_10, create_wres34_10, create_wres34_10_fn
from ares.model.preact_resnet import create_preact_res18
from ares.model.cifar_model_zoo import cifar_model_zoo
from ares.utils.model import NormalizeByChannelMeanStd
from ares.utils.registry import registry

def filter_state_dict(state_dict):
    '''The function to filter state dict.

    Unwraps a checkpoint that nests its weights under ``'state_dict'`` or
    ``'net'``, drops every entry whose key contains ``'sub_block'``, then
    strips a leading ``'module.'`` prefix followed by a leading
    ``'basic_net.'`` prefix from each remaining key.

    Args:
        state_dict (dict): A raw checkpoint or a state dict.

    Returns:
        OrderedDict: The filtered state dict, in the input's key order.
    '''
    if 'state_dict' in state_dict:
        state_dict = state_dict['state_dict']
    elif 'net' in state_dict:
        state_dict = state_dict['net']

    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        if 'sub_block' in k:
            continue
        # Strip only genuine leading prefixes. A substring test followed by a
        # fixed-width slice would corrupt keys that merely contain these names
        # somewhere else (e.g. 'layer.module.w').
        for prefix in ('module.', 'basic_net.'):
            if k.startswith(prefix):
                k = k[len(prefix):]
        new_state_dict[k] = v
    return new_state_dict
@registry.register_model('CifarCLS')
class CifarCLS(torch.nn.Module):
    '''The class to create cifar10 model.

    Builds the architecture named by a ``cifar_model_zoo`` entry, downloads its
    checkpoint into the registry cache directory if needed, loads the weights,
    and optionally prepends a per-channel mean/std normalization layer.
    '''

    def __init__(self, model_name, normalize=True):
        '''
        Args:
            model_name (str): The model name in the cifar10 model zoo.
            normalize (bool): Whether to integrate the normalization layer into the model.

        Raises:
            ValueError: If the zoo entry names an unsupported architecture.
        '''
        super().__init__()
        self.model_name = model_name
        self.normalize = normalize
        zoo_entry = cifar_model_zoo[self.model_name]
        self.arch = zoo_entry['model']
        mean = zoo_entry['mean']
        std = zoo_entry['std']

        # Architecture identifiers used in the zoo -> model constructors.
        builders = {
            'preact_resnet18': create_preact_res18,
            'wresnet34_10_fn': create_wres34_10_fn,
            'wresnet28_10': create_wres28_10,
            'wresnet34_10': create_wres34_10,
        }
        if self.arch not in builders:
            raise ValueError('Model not supported.')
        self.model = builders[self.arch]()

        self.url = zoo_entry['url']
        self.pt_name = zoo_entry['pt']
        self.model_path = os.path.join(registry.get_path('cache_dir'), self.pt_name)
        # Fetch the checkpoint only when it is not cached yet, so a previously
        # downloaded model can be loaded without network access.
        if not os.path.isfile(self.model_path):
            gdown.download(self.url, self.model_path, quiet=False, resume=True)
        # Weights are loaded into the bare network, before the optional
        # normalization wrapper is added, so checkpoint keys match.
        self.load()

        if self.normalize:
            normalization = NormalizeByChannelMeanStd(mean=mean, std=std)
            self.model = torch.nn.Sequential(normalization, self.model)

    def forward(self, x):
        '''
        Args:
            x (torch.Tensor): The input images. The images should be torch.Tensor with shape [N, C, H, W] and range [0, 1].
                When ``normalize`` was False, no normalization is applied before the wrapped network.

        Returns:
            torch.Tensor: The output logits with shape [N, D].
        '''
        logits = self.model(x)
        return logits

    def load(self):
        '''The function to load ckpt.

        Reads the checkpoint at ``self.model_path`` onto the CPU, unwraps and
        renames its keys with ``filter_state_dict``, loads the result into
        ``self.model`` and switches that network to evaluation mode.
        '''
        # NOTE(review): torch.load unpickles the file, and this checkpoint is
        # downloaded from a remote URL; only use trusted model-zoo sources.
        checkpoint = torch.load(self.model_path, map_location='cpu')
        state_dict = filter_state_dict(checkpoint)
        self.model.load_state_dict(state_dict)
        self.model.eval()