import kornia as K
import torch
from diamond_square import functional_diamond_square
from .deterministic_image_augmentation import AugmentationState
from .static_image_augmentation import StaticImageAugmentation
from .random import Uniform, Normal, Bernoulli
class ColorAugmentation(StaticImageAugmentation):
    r"""Abstract base for augmentations that manipulate the colorspace.

    Inputs are expected to be interpretable as RGB with values in the range [0., 1.].
    Three-channel batches are transformed directly. Single-channel batches are replicated
    to three channels, transformed, and converted back with ``K.color.rgb_to_grayscale``.
    Batches with any other channel count are returned unchanged (identity).

    Subclasses only need to define ``generate_batch_state(self, batch: torch.FloatTensor)``
    and the classmethod ``functional_image(cls, batch: torch.FloatTensor, *batch_state)``.
    """

    def forward_img(self, batch_tensor: torch.FloatTensor) -> torch.FloatTensor:
        # The random state is sampled from the input as given (before any channel
        # replication), and it is sampled even when the batch ends up untouched.
        state = self.generate_batch_state(batch_tensor)
        apply = type(self).functional_image
        num_channels = batch_tensor.size(1)

        if num_channels == 3:
            args = (batch_tensor,) + state
            return apply(*args)

        if num_channels != 1:
            # Neither RGB nor grayscale: there are no colours to operate on.
            return batch_tensor

        # Grayscale: lift to RGB so the colour operation is defined, then project back.
        as_rgb = batch_tensor.repeat([1, 3, 1, 1])
        args = (as_rgb,) + state
        return K.color.rgb_to_grayscale(apply(*args))
class Invert(ColorAugmentation):
    r"""Inverts the HSV value channel of a randomly selected subset of the images.

    Whether each image is inverted is decided by one draw from ``do_inversion`` per image.

    .. image:: _static/example_images/Invert.png
    """

    do_inversion = Bernoulli()

    def generate_batch_state(self, batch: torch.FloatTensor) -> AugmentationState:
        flags = type(self).do_inversion(batch.size(0), device=batch.device)
        return (flags.view(-1).float(),)

    @classmethod
    def functional_image(cls, batch: torch.FloatTensor, do_inversion: torch.FloatTensor) -> torch.FloatTensor:
        # One flag per image (expected to be 0. or 1.), broadcast over channels and pixels.
        mask = do_inversion.view(-1, 1, 1, 1)
        hsv = K.color.rgb_to_hsv(batch)
        value = hsv[:, 2:, :, :]
        # Blend the inverted and the original value channel; the mask selects between them.
        hsv[:, 2:, :, :] = (1 - value) * mask + value * (1 - mask)
        return K.color.hsv_to_rgb(hsv)
class InvertLuminance(ColorAugmentation):
    r"""Inverts the HSV value channel of every image in the batch.

    Unlike :class:`Invert`, no images are selected at random: the operation has no random
    parameters and is applied to the whole batch. Note that the inverted channel is the
    HSV *value* channel (``1 - v``), not luminance in the photometric sense.

    .. image:: _static/example_images/Invert.png
    """

    def generate_batch_state(self, batch: torch.FloatTensor) -> AugmentationState:
        # Nothing to sample: the transform is deterministic.
        return ()

    @classmethod
    def functional_image(cls, batch: torch.FloatTensor) -> torch.FloatTensor:
        hsv_batch = K.color.rgb_to_hsv(batch)
        hsv_batch[:, 2:, :, :] = 1 - hsv_batch[:, 2:, :, :]
        return K.color.hsv_to_rgb(hsv_batch)
class LinearColor(ColorAugmentation):
    r"""Applies a random per-image linear intensity transform ``a * x + (1 - a) * b``.

    ``a`` is drawn from ``a`` (range [.6, 1.] by default) and ``b`` from ``b`` (range
    [0., 1.] by default). The result is a blend of the image with the constant gray level
    ``b``, weighted by ``a``, so inputs in [0., 1.] remain in [0., 1.] without clamping.
    This changes both brightness and contrast.

    .. image:: _static/example_images/Brightness.png
    """

    a = Uniform((.6, 1.))
    b = Uniform((.0, 1.))

    def generate_batch_state(self, batch: torch.FloatTensor) -> AugmentationState:
        a = type(self).a(batch.size(0), device=batch.device).view(-1)
        b = type(self).b(batch.size(0), device=batch.device).view(-1)
        # The state stores the scale and the already-weighted offset (1 - a) * b.
        return a, (1 - a) * b

    @classmethod
    def functional_image(cls, batch: torch.FloatTensor, a: torch.FloatTensor, b: torch.FloatTensor) -> torch.FloatTensor:
        # Here ``b`` is the offset produced by generate_batch_state, i.e. (1 - a) * b.
        return batch * a.view(-1, 1, 1, 1) + b.view(-1, 1, 1, 1)
class KorniaBrightness(ColorAugmentation):
    r"""Changes the brightness of the image using ``K.enhance.adjust_brightness``.

    One brightness factor per image is drawn from ``brightness``.

    .. image:: _static/example_images/Brightness.png
    """

    brightness = Uniform((-1.0, 1.0))

    def generate_batch_state(self, batch: torch.FloatTensor) -> AugmentationState:
        sampler = type(self).brightness
        factors = sampler(batch.size(0), device=batch.device)
        return (factors.view(-1),)

    @classmethod
    def functional_image(cls, batch: torch.FloatTensor, brightness: torch.FloatTensor) -> torch.FloatTensor:
        adjusted = K.enhance.adjust_brightness(batch, brightness)
        return adjusted
class KorniaContrast(ColorAugmentation):
    r"""Changes the contrast of the image using ``K.enhance.adjust_contrast``.

    One contrast factor per image is drawn from ``contrast``.

    .. image:: _static/example_images/Contrast.png
    """

    contrast = Uniform((0.0, 1.0))

    def generate_batch_state(self, batch: torch.FloatTensor) -> AugmentationState:
        contrast = type(self).contrast(batch.size(0), device=batch.device).view(-1)
        return contrast,

    @classmethod
    def functional_image(cls, batch: torch.FloatTensor, contrast: torch.FloatTensor) -> torch.FloatTensor:
        # The contrast operator is used here, matching the class name and docstring.
        return K.enhance.adjust_contrast(batch, contrast)
class Brightness(ColorAugmentation):
    r"""Changes the brightness of the image by adding a per-image offset.

    The offset is drawn from ``brightness``, added to every pixel and channel, and the
    result is clamped to [0., 1.].

    .. image:: _static/example_images/Brightness.png
    """

    brightness = Uniform((-1.0, 1.0))

    def generate_batch_state(self, batch: torch.FloatTensor) -> AugmentationState:
        num_images = batch.size(0)
        offsets = type(self).brightness(num_images, device=batch.device)
        return (offsets.view(-1),)

    @classmethod
    def functional_image(cls, batch: torch.FloatTensor, brightness: torch.FloatTensor) -> torch.FloatTensor:
        offset = brightness.view(-1, 1, 1, 1)
        shifted = batch + offset
        return shifted.clamp(0., 1.)
class Contrast(ColorAugmentation):
    r"""Changes the contrast of the image by scaling with a per-image factor.

    The factor is drawn from ``contrast``, multiplied into every pixel and channel, and
    the result is clamped to [0., 1.].

    .. image:: _static/example_images/Contrast.png
    """

    contrast = Uniform((0.0, 2.0))

    def generate_batch_state(self, batch: torch.FloatTensor) -> AugmentationState:
        num_images = batch.size(0)
        factors = type(self).contrast(num_images, device=batch.device)
        return (factors.view(-1),)

    @classmethod
    def functional_image(cls, batch: torch.FloatTensor, contrast: torch.FloatTensor) -> torch.FloatTensor:
        scale = contrast.view(-1, 1, 1, 1)
        scaled = batch * scale
        return scaled.clamp(0.0, 1.0)
class Saturation(ColorAugmentation):
    r"""Changes the saturation of the image using ``K.enhance.adjust_saturation``.

    One saturation factor per image is drawn from ``saturation``.

    .. image:: _static/example_images/Saturation.png
    """

    saturation = Uniform((0.0, 2.0))

    def generate_batch_state(self, batch: torch.FloatTensor) -> AugmentationState:
        num_images = batch.size(0)
        factors = type(self).saturation(num_images, device=batch.device)
        return (factors.view(-1),)

    @classmethod
    def functional_image(cls, batch: torch.FloatTensor, saturation: torch.FloatTensor) -> torch.FloatTensor:
        saturated = K.enhance.adjust_saturation(batch, saturation)
        return saturated
class Hue(ColorAugmentation):
    r"""Changes the hue of the image using ``K.enhance.adjust_hue``.

    One hue shift per image is drawn from ``hue``.

    .. image:: _static/example_images/Hue.png
    """

    hue = Uniform((-.5, .5))

    def generate_batch_state(self, batch: torch.FloatTensor) -> AugmentationState:
        num_images = batch.size(0)
        shifts = type(self).hue(num_images, device=batch.device)
        return (shifts.view(-1),)

    @classmethod
    def functional_image(cls, batch: torch.FloatTensor, hue: torch.FloatTensor) -> torch.FloatTensor:
        shifted = K.enhance.adjust_hue(batch, hue)
        return shifted
class ColorJitter(ColorAugmentation):
    r"""Changes hue, saturation, brightness and contrast of the image, in that order.

    One factor per image is drawn for each property from the class attribute of the same
    name (``hue``, ``contrast``, ``saturation``, ``brightness``).

    .. image:: _static/example_images/ColorJitter.png
    """

    hue = Uniform((-.5, .5))
    contrast = Uniform((0.0, 1.0))
    saturation = Uniform((0.0, 2.0))
    brightness = Uniform((-1.0, 1.0))

    def generate_batch_state(self, batch: torch.FloatTensor) -> AugmentationState:
        batch_sz = batch.size(0)
        device = batch.device
        # Each factor comes from its own distribution; hue uses ``hue``, not ``contrast``.
        hue = type(self).hue(batch_sz, device=device).view(-1)
        contrast = type(self).contrast(batch_sz, device=device).view(-1)
        saturation = type(self).saturation(batch_sz, device=device).view(-1)
        brightness = type(self).brightness(batch_sz, device=device).view(-1)
        return hue, contrast, saturation, brightness

    @classmethod
    def functional_image(cls, batch: torch.FloatTensor, hue: torch.FloatTensor, contrast: torch.FloatTensor,
                         saturation: torch.FloatTensor, brightness: torch.FloatTensor) -> torch.FloatTensor:
        batch = K.enhance.adjust_hue(batch, hue)
        batch = K.enhance.adjust_saturation(batch, saturation)
        batch = K.enhance.adjust_brightness(batch, brightness)
        batch = K.enhance.adjust_contrast(batch, contrast)
        return batch
class GaussianAdditiveNoise(ColorAugmentation):
    r"""Adds element-wise noise drawn from ``noise`` and clamps the result to [0, 1].

    ``noise`` defaults to a normal distribution with mean 0 and deviation .2. One sample is
    drawn for every element of the input batch. For single-channel inputs the sample has
    one channel, so it is shared by the replicated RGB channels.
    """

    noise = Normal(mean=0, deviation=.2)

    def generate_batch_state(self, batch_tensor: torch.FloatTensor) -> AugmentationState:
        noise = type(self).noise(batch_tensor.size(), device=batch_tensor.device)
        return noise,

    @classmethod
    def functional_image(cls, batch: torch.FloatTensor, noise: torch.FloatTensor) -> torch.FloatTensor:
        return torch.clamp(batch + noise, 0, 1)
class PlasmaBrightness(ColorAugmentation):
    r"""Changes the brightness of the image locally.

    A single-channel plasma fractal from ``functional_diamond_square`` is mapped through
    ``2 * p - 1`` and scaled by a per-image ``intensity``. The resulting map is added to
    every channel, and the output is clamped to [0, 1].

    .. image:: _static/example_images/PlasmaBrightness.png
    """

    roughness = Uniform(value_range=(.1, .7))
    intensity = Uniform(value_range=(0., 1.))

    def generate_batch_state(self, batch_tensor: torch.FloatTensor) -> AugmentationState:
        batch_sz, _, height, width = batch_tensor.size()
        device = batch_tensor.device
        roughness = type(self).roughness(batch_sz, device=device)
        plasma_sz = (batch_sz, 1, height, width)
        intensity = type(self).intensity(batch_sz, device=device).view(-1, 1, 1, 1)
        # Presumably maps a plasma in [0, 1] to [-1, 1], so regions are both brightened
        # and darkened -- TODO confirm the range of functional_diamond_square.
        plasma = functional_diamond_square(plasma_sz, roughness=roughness, device=device)
        brightness_map = 2 * plasma - 1
        return brightness_map * intensity,

    @classmethod
    def functional_image(cls, batch: torch.FloatTensor, brightness_map: torch.FloatTensor) -> torch.FloatTensor:
        return torch.clamp(batch + brightness_map, 0, 1)
class PlasmaRgbBrightness(ColorAugmentation):
    r"""Changes the brightness of each RGB channel locally.

    Like ``PlasmaBrightness``, but the plasma fractal is generated with three channels
    (presumably independent per channel -- TODO confirm ``functional_diamond_square``),
    so the local offsets can differ between R, G and B and shift the colour balance.
    The output is clamped to [0, 1].
    """

    roughness = Uniform(value_range=(.1, .7))
    intensity = Uniform(value_range=(0., 1.))

    def generate_batch_state(self, batch_tensor: torch.FloatTensor) -> AugmentationState:
        batch_sz, _, height, width = batch_tensor.size()
        device = batch_tensor.device
        roughness = type(self).roughness(batch_sz, device=device)
        # Always three channels: ColorAugmentation feeds functional_image RGB batches.
        plasma_sz = (batch_sz, 3, height, width)
        intensity = type(self).intensity(batch_sz, device=device).view(-1, 1, 1, 1)
        plasma = functional_diamond_square(plasma_sz, roughness=roughness, device=device)
        brightness_map = 2 * plasma - 1
        return brightness_map * intensity,

    @classmethod
    def functional_image(cls, batch: torch.FloatTensor, brightness_map: torch.FloatTensor) -> torch.FloatTensor:
        return torch.clamp(batch + brightness_map, 0, 1)
class PlasmaLinearColor(ColorAugmentation):
    r"""Applies a spatially varying linear intensity transform ``alpha * x + beta``.

    ``alpha`` and ``beta`` are two plasma fractals generated with the same per-image
    roughness. Each is rescaled by a per-image range and mean, and ``beta`` also gets an
    extra ``.5 * beta_range``. Both maps are then divided by their sum so that
    ``alpha + beta == 1`` at every pixel; consequently an input value of 1. maps to 1.
    The output is clamped to [0, 1].

    .. image:: _static/example_images/PlasmaLinearColor.png
    """

    roughness = Uniform(value_range=(.1, .4))
    alpha_range = Uniform(value_range=(.0, 1.))
    alpha_mean = Uniform(value_range=(.0, 1.))
    beta_range = Uniform(value_range=(0., 1.))
    beta_mean = Uniform(value_range=(0., 1.))

    def generate_batch_state(self, batch_tensor: torch.FloatTensor) -> AugmentationState:
        batch_sz, _, height, width = batch_tensor.size()
        device = batch_tensor.device
        roughness = type(self).roughness(batch_sz, device=device)
        plasma_sz = (batch_sz, 1, height, width)
        alpha_range = type(self).alpha_range(batch_sz, device=device).view(-1, 1, 1, 1)
        alpha_mean = type(self).alpha_mean(batch_sz, device=device).view(-1, 1, 1, 1)
        beta_range = type(self).beta_range(batch_sz, device=device).view(-1, 1, 1, 1)
        beta_mean = type(self).beta_mean(batch_sz, device=device).view(-1, 1, 1, 1)

        alpha_plasma = functional_diamond_square(plasma_sz, roughness=roughness, device=device)
        alpha_plasma = alpha_plasma * alpha_range + (1 - alpha_range) * alpha_mean
        beta_plasma = functional_diamond_square(plasma_sz, roughness=roughness, device=device)
        beta_plasma = beta_plasma * beta_range + (1 - beta_range) * beta_mean + beta_range * .5

        # Normalise BOTH maps by their sum so that alpha + beta == 1 everywhere.
        # The tiny floor only guards against division by an exactly-zero sum.
        plasma_sum = (alpha_plasma + beta_plasma).clamp(min=1e-8)
        return alpha_plasma / plasma_sum, beta_plasma / plasma_sum

    @classmethod
    def functional_image(cls, batch: torch.FloatTensor, alpha_plasma: torch.FloatTensor, beta_plasma: torch.FloatTensor) -> torch.FloatTensor:
        scaled_color_img = batch * alpha_plasma + beta_plasma
        return torch.clamp(scaled_color_img, 0, 1)
class PlasmaContrast(ColorAugmentation):
    r"""Changes the contrast of the image locally.

    A single-channel plasma fractal scaled by 4 is used as a per-pixel contrast factor
    around mid-gray: ``(x - .5) * c + .5``. The result is clamped to [0, 1].

    .. image:: _static/example_images/PlasmaContrast.png
    """

    roughness = Uniform(value_range=(.1, .7))

    def generate_batch_state(self, batch_tensor: torch.FloatTensor) -> AugmentationState:
        batch_sz, _, height, width = batch_tensor.size()
        device = batch_tensor.device
        roughness = type(self).roughness(batch_sz, device=device)
        plasma_sz = (batch_sz, 1, height, width)
        contrast_map = 4 * functional_diamond_square(plasma_sz, roughness=roughness, device=device)
        return contrast_map,

    @classmethod
    def functional_image(cls, batch: torch.FloatTensor, contrast_map: torch.FloatTensor) -> torch.FloatTensor:
        return torch.clamp((batch - .5) * contrast_map + .5, 0, 1)
class PlasmaShadow(ColorAugmentation):
    r"""Lowers the brightness of the image over a random mask.

    The mask is the set of pixels where a single-channel plasma fractal falls below a
    per-image ``shade_quantity`` threshold. Masked pixels receive the per-image (non-positive)
    ``shade_intensity`` offset, and the result is clamped to [0, 1].

    .. image:: _static/example_images/PlasmaShadow.png
    """

    roughness = Uniform(value_range=(.1, .7))
    shade_intensity = Uniform(value_range=(-1.0, .0))
    shade_quantity = Uniform(value_range=(0.0, 1.0))

    def generate_batch_state(self, batch_tensor: torch.FloatTensor) -> AugmentationState:
        batch_sz, _, height, width = batch_tensor.size()
        device = batch_tensor.device
        roughness = type(self).roughness(batch_sz, device=device)
        shade_intensity = type(self).shade_intensity(batch_sz, device=device).view(-1, 1, 1, 1)
        shade_quantity = type(self).shade_quantity(batch_sz, device=device).view(-1, 1, 1, 1)
        plasma_sz = (batch_sz, 1, height, width)
        shade_map = functional_diamond_square(plasma_sz, roughness=roughness, device=device)
        # Binary mask from the threshold, then the (negative) darkening offset.
        shade_map = (shade_map < shade_quantity).float() * shade_intensity
        return shade_map,

    @classmethod
    def functional_image(cls, batch: torch.FloatTensor, shade_map: torch.FloatTensor) -> torch.FloatTensor:
        return torch.clamp(batch + shade_map, 0, 1)
class PlasmaGaussianAdditiveNoise(ColorAugmentation):
    r"""Adds element-wise noise whose strength varies locally, then clamps to [0, 1].

    One sample from ``noise`` (normal with mean 0 and deviation .2 by default) is drawn for
    every element of the input. It is multiplied by a single-channel plasma fractal, so the
    noise amplitude follows the plasma pattern across the image.
    """

    noise = Normal(mean=0, deviation=.2)
    roughness = Uniform(value_range=(.3, .7))

    def generate_batch_state(self, batch_tensor: torch.FloatTensor) -> AugmentationState:
        tensor_sz = batch_tensor.size()
        batch_sz, _, height, width = tensor_sz
        device = batch_tensor.device
        roughness = type(self).roughness(batch_sz, device=device)
        noise = type(self).noise(tensor_sz, device=device)
        plasma_sz = (batch_sz, 1, height, width)
        # The (B, 1, H, W) plasma broadcasts over the channel dimension of the noise.
        noise_intensity = functional_diamond_square(plasma_sz, roughness=roughness, device=device)
        return noise * noise_intensity,

    @classmethod
    def functional_image(cls, batch: torch.FloatTensor, noise: torch.FloatTensor) -> torch.FloatTensor:
        return torch.clamp(batch + noise, 0, 1)