# Source code for ares.defense.bit_depth_reduction

from ares.utils.registry import registry


@registry.register_model('bit_depth_reduction')
class BitDepthReduction(object):
    '''Bit depth reduction defense method.

    Input images are quantized to steps of ``1 / 2 ** compressed_bit``
    (see :meth:`bit_depth_reduction`) before being returned, discarding
    finer intensity detail.
    '''

    def __init__(self, device='cuda', compressed_bit=4):
        '''
        Args:
            device (torch.device or str): Device the quantized images are
                moved to before being returned. Defaults to 'cuda'.
            compressed_bit (int): Number of bits kept; inputs are quantized
                to steps of ``1 / 2 ** compressed_bit``. Defaults to 4.
        '''
        self.compressed_bit = compressed_bit
        self.device = device

    def __call__(self, images):
        '''Apply bit depth reduction to ``images`` and return the result.'''
        images = self.bit_depth_reduction(images)
        return images

    def bit_depth_reduction(self, xs):
        '''Quantize ``xs`` to steps of ``1 / 2 ** compressed_bit``.

        Each value is scaled by ``2 ** compressed_bit``, truncated toward
        zero and scaled back, i.e. ``trunc(x * 2**b) / 2**b``. The input is
        detached from the autograd graph first.

        Args:
            xs (torch.Tensor): Floating-point images. Values are presumably
                in [0, 1] (TODO confirm with callers); for such inputs
                truncation is the same as flooring.

        Returns:
            torch.Tensor: Quantized tensor with the same shape and dtype as
            ``xs``, moved to ``self.device``.
        '''
        bits = 2 ** self.compressed_bit
        # Tensor.trunc() performs the same truncation toward zero that the
        # former ``.int()`` cast did, but stays in floating point: the input
        # dtype is kept (the int32 path promoted half/double inputs to
        # float32) and large ``compressed_bit`` values cannot overflow int32.
        xs_compress = xs.detach().mul(bits).trunc()
        # The former ``* (255 / bits)`` followed by ``/ 255`` cancels out
        # exactly, so dividing by ``bits`` directly gives the same values.
        # NOTE(review): for inputs in [0, 1], x == 1.0 maps to 1.0, giving
        # 2**b + 1 output levels; feature squeezing (Xu et al.) uses
        # rint(x * (2**b - 1)) / (2**b - 1) instead. Kept as-is so results
        # produced with this defense do not change.
        return (xs_compress / bits).to(self.device)