Source code for tormentor.spatial_image_augmentation

import torch
from .sampling_fileds import SamplingField, apply_sampling_field, create_sampling_field
from .deterministic_image_augmentation import DeterministicImageAugmentation


class SpatialImageAugmentation(DeterministicImageAugmentation):
    r"""Base class for augmentations that relocate pixels.

    Any augmentation where image pixels move around, rather than merely
    change value, should derive from this class. Augmentations that do not
    derive from it are expected to leave pointclouds and sampling fields
    untouched and should derive from StaticImageAugmentation instead.

    Subclasses implement :meth:`functional_sampling_field`; images and masks
    are then handled by building a sampling field, transforming it, and
    applying it to the input.
    """

    @classmethod
    def functional_sampling_field(cls, coords: SamplingField, *state) -> SamplingField:
        """Transform ``coords`` according to ``state``.

        This base implementation always raises; it is the hook that
        :meth:`forward_sampling_field` calls on the concrete class.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def forward_img(self, batch_tensor):
        """Apply the augmentation to a batch of images.

        ``batch_tensor.size()`` is unpacked into exactly four values, read
        as (batch, channels, height, width). A sampling field is created for
        that width, height and batch size on the tensor's device, passed
        through :meth:`forward_sampling_field`, and applied to the input.

        Returns:
            Whatever ``apply_sampling_field`` returns for the input and the
            transformed field.
        """
        # The channel count is not needed to build the sampling field.
        n_samples, _, img_height, img_width = batch_tensor.size()
        initial_field = create_sampling_field(
            img_width,
            img_height,
            batch_size=n_samples,
            device=batch_tensor.device,
        )
        moved_field = self.forward_sampling_field(initial_field)
        return apply_sampling_field(batch_tensor, moved_field)

    def forward_sampling_field(self, coords: SamplingField) -> SamplingField:
        """Apply the augmentation to a sampling field.

        The per-batch state is obtained from ``self.generate_batch_state``
        and forwarded, after ``coords``, as positional arguments to the
        concrete class's :meth:`functional_sampling_field`.
        """
        batch_state = self.generate_batch_state(coords)
        # Resolved through the runtime type so the subclass override is used.
        field_transform = type(self).functional_sampling_field
        # Tuple concatenation: the state has to be a tuple for this to work.
        positional = (coords,) + batch_state
        return field_transform(*positional)

    def forward_mask(self, X: torch.Tensor) -> torch.Tensor:
        """Apply the augmentation to masks; delegates to :meth:`forward_img`."""
        return self.forward_img(X)