Source code for ares.utils.loss

import torch
from torch import nn
from timm.loss import JsdCrossEntropy, BinaryCrossEntropy, LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from contextlib import suppress
from timm.utils import NativeScaler
try:
    from apex import amp
    from timm.utils import ApexScaler
    has_apex = True
except ImportError:
    has_apex = False
has_native_amp = False
try:
    if getattr(torch.cuda.amp, 'autocast') is not None:
        has_native_amp = True
except AttributeError:
    pass


def loss_adv(loss_name, outputs, labels, target_labels, target, device):
    '''Create the adversarial loss used to craft adversarial examples.'''
    if loss_name == "ce":
        loss = nn.CrossEntropyLoss()
        if target:
            cost = -loss(outputs, target_labels)
        else:
            cost = loss(outputs, labels)
    elif loss_name == 'cw':
        if target:
            one_hot_labels = torch.eye(len(outputs[0]))[target_labels].to(device)
            i, _ = torch.max((1 - one_hot_labels) * outputs, dim=1)  # largest non-target logit
            j = torch.masked_select(outputs, one_hot_labels.bool())  # target-class logit
            cost = -torch.clamp((i - j), min=0)  # -self.kappa=0
            cost = cost.sum()
        else:
            one_hot_labels = torch.eye(len(outputs[0]))[labels].to(device)
            i, _ = torch.max((1 - one_hot_labels) * outputs, dim=1)  # largest wrong-class logit
            j = torch.masked_select(outputs, one_hot_labels.bool())  # true-class logit
            cost = -torch.clamp((j - i), min=0)  # -self.kappa=0
            cost = cost.sum()
    return cost
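A minimal usage sketch of loss_adv. The 4x10 random logits, label tensors, and CPU device below are assumptions for illustration only, not part of ares:

example_device = torch.device('cpu')
example_outputs = torch.randn(4, 10)          # hypothetical logits: 4 samples, 10 classes
example_labels = torch.randint(0, 10, (4,))   # ground-truth labels
example_targets = torch.randint(0, 10, (4,))  # labels for a targeted attack

# Untargeted cross-entropy cost (non-negative; an attacker would maximize it).
ce_cost = loss_adv('ce', example_outputs, example_labels, example_targets, False, example_device)
# Targeted CW margin cost (non-positive; reaches 0 once the target class wins on every sample).
cw_cost = loss_adv('cw', example_outputs, example_labels, example_targets, True, example_device)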
def margin_loss(outputs, labels, target_labels, targeted, device):
    '''Define the margin loss.'''
    if targeted:
        one_hot_labels = torch.eye(len(outputs[0]))[target_labels].to(device)
        i, _ = torch.max((1 - one_hot_labels) * outputs, dim=1)
        j = torch.masked_select(outputs, one_hot_labels.bool())
        cost = -torch.clamp((i - j), min=0)  # -self.kappa=0
    else:
        one_hot_labels = torch.eye(len(outputs[0]))[labels].to(device)
        i, _ = torch.max((1 - one_hot_labels) * outputs, dim=1)
        j = torch.masked_select(outputs, one_hot_labels.bool())
        cost = -torch.clamp((j - i), min=0)  # -self.kappa=0
    return cost.sum()
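A sketch of margin_loss in untargeted mode, again with hypothetical tensors. The summed cost is zero once every sample in the batch is misclassified and grows more negative the more confidently the true class wins:

batch_outputs = torch.randn(8, 10)            # hypothetical logits: 8 samples, 10 classes
batch_labels = torch.randint(0, 10, (8,))
cost = margin_loss(batch_outputs, batch_labels, None, False, torch.device('cpu'))
# cost <= 0; an untargeted attack typically perturbs inputs to drive it toward 0.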
def resolve_amp(args, _logger):
    '''Resolve the AMP settings for robust training.'''
    args.amp_version = ''
    # resolve AMP arguments based on PyTorch / Apex availability
    if args.apex_amp and has_apex:
        args.amp_version = 'apex'
    elif args.native_amp and has_native_amp:
        args.amp_version = 'native'
    else:
        _logger.warning("Neither APEX nor native Torch AMP is available, using float32. "
                        "Install NVIDIA Apex or upgrade to PyTorch 1.6.")
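A sketch of how resolve_amp might be called. The SimpleNamespace stands in for parsed command-line arguments; the flag names simply mirror the fields read above and the logger name is an assumption:

import logging
from types import SimpleNamespace

example_logger = logging.getLogger('ares.utils.loss.example')
example_args = SimpleNamespace(apex_amp=False, native_amp=True)
resolve_amp(example_args, example_logger)
# example_args.amp_version is set to 'native' if torch.cuda.amp.autocast is
# available, otherwise it stays '' and a warning is logged.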
def build_loss_scaler(args, _logger):
    '''Build the loss scaler for robust training.'''
    # setup loss scaler
    amp_autocast = suppress  # do nothing
    loss_scaler = None
    if args.amp_version == 'apex':
        # NOTE: this branch expects `model` and `optimizer` to exist in scope;
        # as written they are undefined here, so the caller must wrap the model
        # with Apex itself or this path raises a NameError.
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
        loss_scaler = ApexScaler()
        _logger.info('Using NVIDIA APEX AMP. Training in mixed precision.')
    elif args.amp_version == 'native':
        amp_autocast = torch.cuda.amp.autocast
        loss_scaler = NativeScaler()
        _logger.info('Using native Torch AMP. Training in mixed precision.')
    else:
        _logger.info('AMP not enabled. Training in float32.')
    return amp_autocast, loss_scaler
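A sketch of pairing build_loss_scaler with an already-resolved AMP setting, assuming the native path; the namespace and logger names are hypothetical:

import logging
from types import SimpleNamespace

scaler_logger = logging.getLogger('ares.utils.loss.example')
scaler_args = SimpleNamespace(amp_version='native')
amp_autocast, loss_scaler = build_loss_scaler(scaler_args, scaler_logger)

with amp_autocast():
    # a model forward pass would run here; when amp_version is '' this context
    # is contextlib.suppress and does nothing
    pass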
def build_loss(args, mixup_fn, num_aug_splits):
    '''Build the loss functions for robust training.'''
    if args.jsd_loss:
        assert num_aug_splits > 1  # JSD only valid with aug splits set
        train_loss_fn = JsdCrossEntropy(num_splits=num_aug_splits, smoothing=args.smoothing)
    elif mixup_fn is not None:
        # smoothing is handled with mixup target transform which outputs sparse, soft targets
        if args.bce_loss:
            train_loss_fn = BinaryCrossEntropy(target_threshold=args.bce_target_thresh)
        else:
            train_loss_fn = SoftTargetCrossEntropy()
    elif args.smoothing:
        if args.bce_loss:
            train_loss_fn = BinaryCrossEntropy(smoothing=args.smoothing, target_threshold=args.bce_target_thresh)
        else:
            train_loss_fn = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
    else:
        train_loss_fn = nn.CrossEntropyLoss()
    train_loss_fn = train_loss_fn.cuda()
    validate_loss_fn = nn.CrossEntropyLoss().cuda()
    return train_loss_fn, validate_loss_fn
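A sketch of build_loss with label smoothing and no mixup. A CUDA device is assumed because both returned losses are moved to the GPU, and the namespace fields simply mirror the timm-style flags read above:

from types import SimpleNamespace

loss_args = SimpleNamespace(jsd_loss=False, bce_loss=False, smoothing=0.1,
                            bce_target_thresh=None)
train_loss_fn, validate_loss_fn = build_loss(loss_args, mixup_fn=None, num_aug_splits=0)
# train_loss_fn is LabelSmoothingCrossEntropy(smoothing=0.1) on the GPU;
# validate_loss_fn is a plain CrossEntropyLoss on the GPU.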