Source code for tormentor.wrap

import torch

from diamond_square import functional_diamond_square
from .deterministic_image_augmentation import SamplingField, AugmentationState
from .spatial_image_augmentation import SpatialImageAugmentation
from .static_image_augmentation import StaticImageAugmentation
from .random import Uniform, Bernoulli
from .spatial_augmentations import Perspective
from .sampling_fileds import create_sampling_field, apply_sampling_field

[docs]class Wrap(SpatialImageAugmentation): r"""Augmentation Wrap. This augmentation acts like many simultaneous elastic transforms with gaussian sigmas set at varius harmonics. Distributions: ``roughness``: Quantification of the local inconsistency of the distortion effect. ``intensity``: Quantification of the intensity of the distortion effect. .. image:: _static/example_images/Wrap.png """ roughness = Uniform(value_range=(.1, .7)) intensity = Uniform(value_range=(.0, 1.)) def generate_batch_state(self, sampling_tensors: SamplingField) -> AugmentationState: batch_sz, height, width = sampling_tensors[0].size() roughness = type(self).roughness(batch_sz, device=sampling_tensors[0].device) intensity = type(self).intensity(batch_sz, device=sampling_tensors[0].device) plasma_sz = (batch_sz, 1, height, width) plasma_x = functional_diamond_square(plasma_sz, roughness=roughness, device=sampling_tensors[0].device) - .5 plasma_y = functional_diamond_square(plasma_sz, roughness=roughness, device=sampling_tensors[0].device) - .5 plasma_x, plasma_y = plasma_x[:,0,:,:], plasma_y[:, 0,:,:] plasma_dx = plasma_x[:, :, 1:] - plasma_x[:, :, :-1] plasma_dy = plasma_y[:, 1:, :] - plasma_y[:, :-1, :] plasma_scale_x = torch.cat([abs(plasma_dx.view(batch_sz, -1).min(dim=1)[0]).view(1, -1), plasma_dx.view(batch_sz, -1).max(dim=1)[0].view(1, -1)], dim=0).max(dim=0)[0] plasma_scale_y = torch.cat([abs(plasma_dy.view(batch_sz, -1).min(dim=1)[0]).view(1, -1), plasma_dy.view(batch_sz, -1).max(dim=1)[0].view(1, -1)], dim=0).max(dim=0)[0] plasma_x /= ((plasma_scale_x.view(-1, 1, 1) * .25 * width) / intensity.view(-1, 1, 1)) plasma_y /= ((plasma_scale_y.view(-1, 1, 1) * .25 * height) / intensity.view(-1, 1, 1)) return plasma_x, plasma_y @classmethod def functional_sampling_field(cls, sampling_field: SamplingField, plasma_x: torch.FloatTensor, plasma_y: torch.FloatTensor) -> SamplingField: field_x, field_y = sampling_field return field_x + plasma_x[:, :, :], field_y + plasma_y[:, :, :]
# class Shred(StaticImageAugmentation): # r"""Augmentation Shred. # # # Distributions: # ``roughness``: Quantification of the local inconsistency of the distortion effect. # ``erase_percentile``: Quantification of the surface that will be erased. # ``inside``: If True # # .. image:: _static/example_images/Shred.png # """ # roughness = Uniform(value_range=(.4, .8)) # inside = Bernoulli(prob=.5) # erase_percentile = Uniform(value_range=(.0, .5)) # # def generate_batch_state(self, image_batch: torch.Tensor) -> AugmentationState: # batch_sz, _, width, height = image_batch.size() # roughness = type(self).roughness(batch_sz, device=image_batch.device) # plasma_sz = (batch_sz, 1, width, height) # plasma = functional_diamond_square(plasma_sz, roughness=roughness, device=image_batch.device) # inside = type(self).inside(batch_sz, device=image_batch.device).float() # erase_percentile = type(self).erase_percentile(batch_sz, device=image_batch.device) # return plasma, inside, erase_percentile # # @classmethod # def functional_image(cls, image_batch: torch.Tensor, plasma: torch.FloatTensor, inside: torch.FloatTensor, # erase_percentile: torch.FloatTensor) -> torch.Tensor: # inside = inside.view(-1, 1, 1, 1) # erase_percentile = erase_percentile.view(-1, 1, 1, 1) # plasma = inside * plasma + (1 - inside) * (1 - plasma) # plasma_pixels = plasma.view(plasma.size(0), -1) # thresholds = [] # for n in range(plasma_pixels.size(0)): # thresholds.append(torch.kthvalue(plasma_pixels[n], int(plasma_pixels.size(1) * erase_percentile[n]))[0]) # thresholds = torch.Tensor(thresholds).view(-1, 1, 1, 1).to(plasma.device) # erase = (plasma < thresholds).float() # return image_batch * (1 - erase) # class ShredInside(StaticImageAugmentation): r"""Augmentation Shred. Distributions: ``roughness``: Quantification of the local inconsistency of the distortion effect. ``erase_percentile``: Quantification of the surface that will be erased. ``inside``: If True .. image:: _static/example_images/Shred.png """ roughness = Uniform(value_range=(.4, .8)) erase_percentile = Uniform(value_range=(.0, .2)) def generate_batch_state(self, image_batch: torch.Tensor) -> AugmentationState: batch_sz, _, width, height = image_batch.size() roughness = type(self).roughness(batch_sz, device=image_batch.device) plasma_sz = (batch_sz, 1, width, height) plasma = functional_diamond_square(plasma_sz, roughness=roughness, device=image_batch.device) erase_percentile = type(self).erase_percentile(batch_sz, device=image_batch.device) return plasma, erase_percentile @classmethod def functional_image(cls, image_batch: torch.Tensor, plasma: torch.FloatTensor, erase_percentile: torch.FloatTensor) -> torch.Tensor: #inside = inside.view(-1, 1, 1, 1) erase_percentile = erase_percentile.view(-1, 1, 1, 1) #plasma = inside * plasma + (1 - inside) * (1 - plasma) plasma_pixels = plasma.view(plasma.size(0), -1) thresholds = [] for n in range(plasma_pixels.size(0)): thresholds.append(torch.kthvalue(plasma_pixels[n], int(plasma_pixels.size(1) * erase_percentile[n]))[0]) thresholds = torch.Tensor(thresholds).view(-1, 1, 1, 1).to(plasma.device) erase = (plasma < thresholds).float() return image_batch * (1 - erase) class ShredOutside(StaticImageAugmentation): r"""Augmentation Shred. Distributions: ``roughness``: Quantification of the local inconsistency of the distortion effect. ``erase_percentile``: Quantification of the surface that will be erased. ``inside``: If True .. image:: _static/example_images/Shred.png """ roughness = Uniform(value_range=(.1, .7)) erase_percentile = Uniform(value_range=(.0, .5)) def generate_batch_state(self, image_batch: torch.Tensor) -> AugmentationState: batch_sz, _, width, height = image_batch.size() roughness = type(self).roughness(batch_sz, device=image_batch.device) plasma_sz = (batch_sz, 1, width, height) plasma = functional_diamond_square(plasma_sz, roughness=roughness, device=image_batch.device) erase_percentile = type(self).erase_percentile(batch_sz, device=image_batch.device) return plasma, erase_percentile @classmethod def functional_image(cls, image_batch: torch.Tensor, plasma: torch.FloatTensor, erase_percentile: torch.FloatTensor) -> torch.Tensor: #inside = inside.view(-1, 1, 1, 1) erase_percentile = erase_percentile.view(-1, 1, 1, 1) #plasma = inside * plasma + (1 - inside) * (1 - plasma) plasma_pixels = 1 - plasma.view(plasma.size(0), -1) thresholds = [] for n in range(plasma_pixels.size(0)): thresholds.append(torch.kthvalue(plasma_pixels[n], int(plasma_pixels.size(1) * erase_percentile[n]))[0]) thresholds = torch.Tensor(thresholds).view(-1, 1, 1, 1).to(plasma.device) erase = (plasma < thresholds).float() return image_batch * (1 - erase)