Source code for ares.model.imagenet_cls

import torch
import os
import gdown
import torch.nn as nn
from timm.models import create_model
from ares.utils.model import NormalizeByChannelMeanStd
from ares.model.resnet import resnet50, wide_resnet50_2, ResNetGELU
from ares.model.resnet_denoise import resnet152_fd
from ares.model import vit_mae
from ares.model.imagenet_model_zoo import imagenet_model_zoo
from ares.utils.registry import registry

@registry.register_model('ImageNetCLS')
class ImageNetCLS(torch.nn.Module):
    '''Wrapper that builds an ImageNet classifier from the ImageNet model zoo.

    The zoo entry selected by ``model_name`` supplies the backbone name, the
    normalization statistics, the checkpoint location and whether the
    backbone should be created with pretrained weights.
    '''

    def __init__(self, model_name, normalize=True):
        '''
        Args:
            model_name (str): The model name in the ImageNet model zoo.
            normalize (bool): Whether to integrate the normalization layer into the model.
        '''
        super().__init__()
        self.model_name = model_name
        self.normalize = normalize

        # All per-model settings come from a single zoo entry.
        zoo_entry = imagenet_model_zoo[self.model_name]
        self.backbone = zoo_entry['model']
        mean = zoo_entry['mean']
        std = zoo_entry['std']
        self.pretrained = zoo_entry['pretrained']
        use_gelu = zoo_entry['act_gelu']

        # Backbones implemented inside this project, built without arguments.
        local_builders = {
            'resnet50_rl': resnet50,
            'wide_resnet50_2_rl': wide_resnet50_2,
            'resnet152_fd': resnet152_fd,
        }
        if self.backbone in local_builders:
            net = local_builders[self.backbone]()
        elif self.backbone in ('vit_base_patch16', 'vit_large_patch16'):
            # The ViT constructors are looked up by name in the vit_mae module.
            net = vit_mae.__dict__[self.backbone](num_classes=1000, global_pool='')
        else:
            # Everything else is delegated to timm's model factory.
            timm_kwargs = {'num_classes': 1000}
            if use_gelu:
                timm_kwargs['act_layer'] = ResNetGELU
            net = create_model(self.backbone, pretrained=self.pretrained, **timm_kwargs)
        self.model = net

        self.url = zoo_entry['url']
        # The checkpoint file name is only read from the zoo when the entry is
        # not marked as pretrained; otherwise the path is the cache dir itself.
        if self.pretrained:
            ckpt_name = ''
        else:
            ckpt_name = zoo_entry['pt']
        self.model_path = os.path.join(registry.get_path('cache_dir'), ckpt_name)

        if self.url:
            gdown.download(self.url, self.model_path, quiet=False, resume=True)
        self.load()

        # Wrap only after load(): the checkpoint is applied to the bare
        # backbone via self.model.load_state_dict, before Sequential wraps it.
        if self.normalize:
            input_norm = NormalizeByChannelMeanStd(mean=mean, std=std)
            self.model = torch.nn.Sequential(input_norm, self.model)

    def forward(self, x):
        '''
        Args:
            x (torch.Tensor): The input images. The images should be torch.Tensor with shape [N, C, H, W] and range [0, 1].

        Returns:
            torch.Tensor: The output logits with shape [N, D].
        '''
        return self.model(x)

    def load(self):
        '''Load the checkpoint at ``self.model_path`` unless the model is marked pretrained.'''
        if self.pretrained:
            return
        # NOTE(review): depending on the installed torch version, torch.load may
        # fully unpickle the file - only point this at trusted checkpoints.
        state_dict = torch.load(self.model_path, map_location='cpu')
        self.model.load_state_dict(state_dict)