from .random import Categorical
from .deterministic_image_augmentation import DeterministicImageAugmentation, SamplingField, PointCloudList, PointCloudsImages
import torch
class AugmentationChoice(DeterministicImageAugmentation):
    r"""Select randomly among many augmentations.

    For every sample of a batch, an index is drawn from the class-level ``choice``
    distribution. The augmentation class at that index in ``available_augmentations``
    is instantiated and applied to that sample alone. The per-sample results are then
    concatenated back along dim 0. Both class attributes are set by :meth:`create`.

    .. figure:: _static/example_images/AugmentationChoice.png

        Random choice of perspective and plasma-brightness augmentations

    .. code-block:: python

        augmentation_factory = tormentor.RandomPerspective ^ tormentor.RandomPlasmaBrightness
        augmentation = augmentation_factory()
        augmented_image = augmentation(image)
    """

    @classmethod
    def create(cls, augmentation_list, requires_grad=False, new_cls_name=None):
        """Build a new subclass of ``cls`` that chooses among ``augmentation_list``.

        Args:
            augmentation_list: Sized sequence of augmentation classes to choose from. Each
                entry must provide ``get_distributions()``.
            requires_grad: Passed to ``requires_grad_`` on every distribution parameter of
                every listed augmentation. Those distributions are shared with the listed
                classes, so this mutates them too.
            new_cls_name: Name of the generated class. When ``None``, a name of the form
                ``<base>_<7-digit random number>`` is generated. ``<base>`` is ``cls``'s
                qualified name with everything after its last underscore removed, if it
                has one.

        Returns:
            The new class. Its class attributes are:

            - ``choice``: a ``Categorical`` over the list indices.
            - ``available_augmentations``: the list itself.
            - Every distribution of every listed augmentation, re-keyed as
              ``<ClassName>_<distribution name>``.

        Raises:
            ValueError: If ``augmentation_list`` is empty.
        """
        if len(augmentation_list) == 0:
            raise ValueError("AugmentationChoice requires at least one augmentation to choose from.")
        new_parameters = {"choice": Categorical(len(augmentation_list)), "available_augmentations": augmentation_list}
        for augmentation in augmentation_list:
            # Assuming the default class repr "<class 'pkg.module.Name'>", this keeps the text
            # after the last dot and drops the trailing "'>".
            class_name = str(augmentation).split(".")[-1][:-2]
            cls_distributions = {f"{class_name}_{k}": v for k, v in augmentation.get_distributions().items()}
            new_parameters.update(cls_distributions)
            for cls_distribution in cls_distributions.values():
                for parameter in cls_distribution.get_distribution_parameters().values():
                    parameter.requires_grad_(requires_grad)
        if new_cls_name is None:
            # Drop everything after the last underscore (e.g. a suffix generated by an
            # earlier call), then append a fresh random suffix.
            base_name, separator, _ = cls.__qualname__.rpartition("_")
            cls_oldname = base_name if separator else cls.__qualname__
            new_cls_name = f"{cls_oldname}_{torch.randint(1000000, 9000000, (1,)).item()}"
        return type(new_cls_name, (cls,), new_parameters)

    def _instantiate_choices(self, augmentation_ids, batch_sz):
        """Yield ``(sample_n, augmentation)`` for each sample, instantiating the chosen class on demand.

        The instance for a sample is created only when the caller asks for it, i.e. right
        before that sample is processed. If constructors or per-sample forward calls draw
        from the global RNG, the draws therefore happen in the same order as in a plain
        per-sample loop.
        """
        available_augmentations = type(self).available_augmentations
        for sample_n in range(batch_sz):
            yield sample_n, available_augmentations[augmentation_ids[sample_n]]()

    def forward_sampling_field(self, coords: SamplingField):
        """Apply each sample's chosen augmentation to a sampling field.

        ``coords`` is a pair of batch-first tensors. They are sliced with three indices,
        so each has at least three dims. Each sample's slice is passed through its own
        augmentation. Returns the pair of outputs concatenated along dim 0.
        """
        batch_sz = coords[0].size(0)
        augmentation_ids = type(self).choice(batch_sz)
        augmented_batch_x = []
        augmented_batch_y = []
        for sample_n, augmentation in self._instantiate_choices(augmentation_ids, batch_sz):
            sample_coords = coords[0][sample_n: sample_n + 1, :, :], coords[1][sample_n: sample_n + 1, :, :]
            sample_x, sample_y = augmentation.forward_sampling_field(sample_coords)
            augmented_batch_x.append(sample_x)
            augmented_batch_y.append(sample_y)
        return torch.cat(augmented_batch_x, dim=0), torch.cat(augmented_batch_y, dim=0)

    def forward_img(self, batch_tensor):
        """Apply each sample's chosen augmentation to an image batch.

        ``batch_tensor`` is batch-first and is sliced with four indices (presumably
        (B, C, H, W), to be confirmed). Returns the augmented samples concatenated
        along dim 0.
        """
        batch_sz = batch_tensor.size(0)
        augmentation_ids = type(self).choice(batch_sz)
        augmented_batch = [
            augmentation.forward_img(batch_tensor[sample_n:sample_n + 1, :, :, :])
            for sample_n, augmentation in self._instantiate_choices(augmentation_ids, batch_sz)
        ]
        return torch.cat(augmented_batch, dim=0)

    def forward_img_path_probabilities(self, batch_tensor: torch.FloatTensor) -> torch.FloatTensor:
        """Return, per sample, the probability of the augmentation path that was taken.

        Each entry is the probability of the chosen index under ``choice`` multiplied by
        the first entry of the chosen augmentation's own
        ``forward_img_path_probabilities`` for that sample.
        """
        batch_sz = batch_tensor.size(0)
        augmentation_ids = type(self).choice(batch_sz)
        # Indexing with the id sequence is advanced indexing, which returns a copy. The
        # in-place updates below therefore leave the distribution's ``probs`` untouched.
        probs = type(self).choice.probs[0, augmentation_ids]
        for sample_n, augmentation in self._instantiate_choices(augmentation_ids, batch_sz):
            sample_tensor = batch_tensor[sample_n:sample_n + 1, :, :, :]
            probs[sample_n] *= augmentation.forward_img_path_probabilities(sample_tensor)[0]
        return probs

    def forward_mask(self, batch_tensor):
        """Apply each sample's chosen augmentation to a mask batch.

        Mirrors :meth:`forward_img` but calls ``forward_mask`` on the chosen
        augmentations. Returns the results concatenated along dim 0.
        """
        batch_sz = batch_tensor.size(0)
        augmentation_ids = type(self).choice(batch_sz)
        augmented_batch = [
            augmentation.forward_mask(batch_tensor[sample_n:sample_n + 1, :, :, :])
            for sample_n, augmentation in self._instantiate_choices(augmentation_ids, batch_sz)
        ]
        return torch.cat(augmented_batch, dim=0)

    def forward_pointcloud(self, pcl: PointCloudList, batch_tensor: torch.FloatTensor,
                           compute_img: bool) -> PointCloudsImages:
        """Apply each sample's chosen augmentation to its point clouds and, optionally, its image.

        Returns a ``(point_clouds, images)`` pair. ``point_clouds`` joins every sample's
        augmented point-cloud list in batch order. ``images`` is the per-sample outputs
        concatenated along dim 0 when ``compute_img`` is true. Otherwise it is the input
        ``batch_tensor``, returned as is.
        """
        batch_sz = batch_tensor.size(0)
        augmentation_ids = type(self).choice(batch_sz)
        augmented_batch = []
        augmented_pcl = []
        for sample_n, augmentation in self._instantiate_choices(augmentation_ids, batch_sz):
            sample_tensor = batch_tensor[sample_n:sample_n + 1, :, :, :]
            pc_onelist = pcl[sample_n: sample_n + 1]
            aug_pc_onelist, augmented_sample = augmentation.forward_pointcloud(pc_onelist, sample_tensor, compute_img)
            augmented_batch.append(augmented_sample)
            # extend() grows the list in place instead of rebuilding it for every sample.
            augmented_pcl.extend(aug_pc_onelist)
        if compute_img:
            return augmented_pcl, torch.cat(augmented_batch, dim=0)
        return augmented_pcl, batch_tensor