Source code for ares.attack.detection.custom.lr_scheduler

from torch.optim import lr_scheduler
from ares.utils.registry import Registry
from ares.utils.logger import format_print
@Registry.register_lr_scheduler()
class ALRS:
    """Reference: Bootstrap Generalization Ability from Loss Landscape Perspective."""

    def __init__(self, optimizer, loss_threshold=1e-4, loss_ratio_threshold=1e-4,
                 decay_rate=0.97, patience=10, last_epoch=-1, verbose=False):
        self.optimizer = optimizer
        self.loss_threshold = loss_threshold
        self.decay_rate = decay_rate
        self.loss_ratio_threshold = loss_ratio_threshold
        self.last_loss = 999
        self.total_epoch_loss = 0
        self.patience = patience
        self.last_epoch = last_epoch
        self.verbose = verbose

    def update_lr(self, loss):
        # Decay the lr only when both the absolute and the relative loss
        # improvement fall below their thresholds, i.e. the loss has plateaued.
        delta = self.last_loss - loss
        if delta < self.loss_threshold and delta / self.last_loss < self.loss_ratio_threshold:
            for ind, group in enumerate(self.optimizer.param_groups):
                self.optimizer.param_groups[ind]['lr'] *= self.decay_rate
                now_lr = group['lr']
                if self.verbose:
                    print(f'now lr = {now_lr}')

    @format_print()
    def step(self, loss, epoch=None):
        if epoch is None:
            self.last_epoch += 1
        else:
            self.last_epoch = epoch
        if self.last_epoch % self.patience != 0:
            # Accumulate the loss until a full patience window has elapsed.
            self.total_epoch_loss += loss
        else:
            # Average the loss over the window, then test for a plateau.
            loss = self.total_epoch_loss / self.patience
            self.update_lr(loss)
            self.last_loss = loss
            self.total_epoch_loss = 0
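A minimal usage sketch (the training helper below is hypothetical; ALRS only touches optimizer.param_groups, so any torch optimizer works):

import torch

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = ALRS(optimizer, decay_rate=0.97, patience=10, verbose=True)

for epoch in range(100):
    epoch_loss = train_one_epoch(model, optimizer)  # hypothetical helper returning the mean epoch loss
    scheduler.step(epoch_loss)  # decays the lr once the windowed loss plateaus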
@Registry.register_lr_scheduler()
class warmupALRS(ALRS):
    """Reference: Bootstrap Generalization Ability from Loss Landscape Perspective."""

    def __init__(self, optimizer, warmup_epoch=50, loss_threshold=1e-4, loss_ratio_threshold=1e-4,
                 decay_rate=0.97, last_epoch=-1, verbose=False):
        # Pass last_epoch and verbose by keyword so they are not mistaken
        # for ALRS's patience argument.
        super().__init__(optimizer, loss_threshold, loss_ratio_threshold, decay_rate,
                         last_epoch=last_epoch, verbose=verbose)
        self.warmup_rate = 1 / 3
        self.warmup_epoch = warmup_epoch
        self.start_lr = optimizer.param_groups[0]["lr"]
        self.warmup_lr = self.start_lr * (1 - self.warmup_rate)
        # Warmup starts from start_lr * warmup_rate.
        self.update_lr(lambda x: x * self.warmup_rate)

    def update_lr(self, update_fn):
        # Unlike ALRS.update_lr, this receives a function mapping the old lr to the new one.
        for ind, group in enumerate(self.optimizer.param_groups):
            self.optimizer.param_groups[ind]['lr'] = update_fn(group['lr'])
            now_lr = group['lr']
            if self.verbose:
                print(f'now lr = {now_lr}')

    @format_print()
    def step(self, loss, epoch=None):
        if epoch is None:
            self.last_epoch += 1
        else:
            self.last_epoch = epoch
        if self.last_epoch < self.warmup_epoch:
            # Linear warmup from start_lr * warmup_rate up to start_lr.
            # Use self.last_epoch rather than epoch, which may be None.
            self.update_lr(lambda x: -(self.warmup_epoch - self.last_epoch) * self.warmup_lr / self.warmup_epoch
                           + self.start_lr)
        elif self.last_epoch % self.patience != 0:
            self.total_epoch_loss += loss
        else:
            loss = self.total_epoch_loss / self.patience
            self.total_epoch_loss = 0  # reset the window, as in ALRS.step
            delta = self.last_loss - loss
            self.last_loss = loss
            if delta < self.loss_threshold and delta / self.last_loss < self.loss_ratio_threshold:
                self.update_lr(lambda x: x * self.decay_rate)
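A quick sanity check on the warmup arithmetic (a toy sketch; the single dummy parameter exists only to build an optimizer): with start_lr = 0.1 and warmup_epoch = 5, the lr starts at 0.1/3 and climbs linearly toward 0.1, after which the ALRS plateau rule takes over.

import torch

optimizer = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.1)
scheduler = warmupALRS(optimizer, warmup_epoch=5)
for _ in range(5):
    scheduler.step(loss=1.0)  # the loss is ignored during warmup
    print(optimizer.param_groups[0]['lr'])  # ~0.033, 0.047, 0.060, 0.073, 0.087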
@Registry.register_lr_scheduler()
class CosineLR(lr_scheduler.CosineAnnealingLR):
    '''See torch.optim.lr_scheduler.CosineAnnealingLR for details.'''

    def __init__(self, optimizer, T_max, eta_min=0, last_epoch=-1, verbose=False):
        super().__init__(optimizer, T_max, eta_min, last_epoch, verbose)

    @format_print()
    def step(self, epoch=None, **kwargs) -> None:
        super().step(epoch)
@Registry.register_lr_scheduler()
class ExponentialLR(lr_scheduler.ExponentialLR):
    '''See torch.optim.lr_scheduler.ExponentialLR for details.'''

    def __init__(self, optimizer, gamma, last_epoch=-1, verbose=False):
        super().__init__(optimizer, gamma, last_epoch, verbose)

    @format_print()
    def step(self, epoch=None, **kwargs) -> None:
        super().step(epoch)
@Registry.register_lr_scheduler()
class PlateauLR(lr_scheduler.ReduceLROnPlateau):
    '''See torch.optim.lr_scheduler.ReduceLROnPlateau for details.'''

    def __init__(self, optimizer, mode='min', factor=0.1, patience=10, threshold=1e-4,
                 threshold_mode='rel', cooldown=0, min_lr=0, eps=1e-8, verbose=False):
        super().__init__(optimizer, mode, factor, patience, threshold,
                         threshold_mode, cooldown, min_lr, eps, verbose)

    @format_print()
    def step(self, metrics, epoch=None, **kwargs) -> None:
        super().step(metrics, epoch)
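Note that, unlike the other wrappers here, PlateauLR.step takes the monitored metric as its first argument. A toy sketch with synthetic loss values:

import torch

optimizer = torch.optim.SGD([torch.zeros(1, requires_grad=True)], lr=0.1)
scheduler = PlateauLR(optimizer, mode='min', factor=0.5, patience=3)
for epoch in range(20):
    val_loss = max(0.1, 1.0 - 0.1 * epoch)  # improves, then plateaus at 0.1
    scheduler.step(val_loss)  # lr is halved after 3 epochs without improvement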
@Registry.register_lr_scheduler()
class MultiStepLR(lr_scheduler.MultiStepLR):
    '''See torch.optim.lr_scheduler.MultiStepLR for details.'''

    def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1, verbose=False):
        super().__init__(optimizer, milestones, gamma, last_epoch, verbose)

    @format_print()
    def step(self, epoch=None, **kwargs) -> None:
        super().step(epoch)
def build_lr_scheduler(optimizer, **kwargs):
    '''Build a learning rate scheduler from the given optimizer, the registered scheduler name ('type') and its constructor arguments ('kwargs').'''
    return Registry.get_lr_scheduler(kwargs['type'])(optimizer, **kwargs['kwargs'])
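A minimal config-driven sketch, assuming the registry resolves schedulers by their class names (an assumption; check Registry.register_lr_scheduler for the actual key):

import torch

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# 'type' is assumed to be the registered name (presumably the class name);
# 'kwargs' is forwarded to the scheduler's constructor.
cfg = {'type': 'CosineLR', 'kwargs': {'T_max': 100, 'eta_min': 1e-6}}
scheduler = build_lr_scheduler(optimizer, **cfg)

for _ in range(100):
    ...  # one training epoch
    scheduler.step()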