Shortcuts

Source code for torchgeo.transforms.spatial

# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.

"""TorchGeo augmentations."""

from typing import Any

import kornia.augmentation as K
import torch
from kornia.augmentation.random_generator import PlainUniformGenerator
from torch import Tensor


class SatSlideMix(K.GeometricAugmentationBase2D):
    """Applies the Sat-SlideMix augmentation to a batch of images and masks.

    Sat-SlideMix rolls (circularly shifts) images along either the height or
    width axis by a random amount.

    If you use this method in your research, please cite the following paper:

    * https://doi.org/10.1609/aaai.v39i27.35028

    .. versionadded:: 0.8
    """

    def __init__(
        self,
        gamma: int = 1,
        beta: Tensor | float | tuple[float, float] | list[float] = (0.0, 1.0),
        p: float = 0.5,
    ) -> None:
        """Initialize a new SatSlideMix instance.

        Args:
            gamma: The number of augmented samples to create for each input
                image. The output batch size will be gamma * B.
            beta: The range of percentage (0.0 to 1.0) of the image dimension
                (height or width) to shift.
            p: Probability to apply the augmentation on each sample

        Raises:
            AssertionError: If `gamma` is not a positive integer.
        """
        super().__init__(p=p)

        # Raise explicitly rather than via ``assert`` so the check is not
        # stripped when Python runs with ``-O``; the exception type and
        # message are unchanged.
        if not (isinstance(gamma, int) and gamma > 0):
            raise AssertionError('gamma must be a positive integer')

        # Three values are sampled per output sample:
        #   * 'beta': fraction of the chosen axis to shift by
        #   * 'dim': rounded in apply_transform to pick height or width
        #   * 'direction': rounded in apply_transform to pick the roll sign
        self._param_generator: PlainUniformGenerator = PlainUniformGenerator(
            (beta, 'beta', 0.5, (0.0, 1.0)),
            ((0.0, 1.0), 'dim', 0.5, (0.0, 1.0)),
            ((0.0, 1.0), 'direction', 0.5, (0.0, 1.0)),
        )
        self.flags = {'gamma': gamma}

    def generate_parameters(self, batch_shape: tuple[int, ...]) -> dict[str, Tensor]:
        """Generate parameters for the batch.

        The leading (batch) dimension is multiplied by ``gamma`` before
        sampling so that every repeated copy of an image gets its own
        parameters.

        Args:
            batch_shape: shape of the input batch, unpacked as (B, C, H, W)

        Returns:
            the sampled parameters returned by the parameter generator
        """
        B, C, H, W = batch_shape
        expanded_shape = torch.Size((B * self.flags['gamma'], C, H, W))
        params: dict[str, Tensor] = self._param_generator(
            expanded_shape, self.same_on_batch
        )
        return params

    def compute_transformation(
        self, input: Tensor, params: dict[str, Tensor], flags: dict[str, Any]
    ) -> Tensor:
        """Compute the transformation.

        A circular roll is not described by a transformation matrix here, so
        the result of ``self.identity_matrix(input)`` is returned.

        Args:
            input: the input tensor
            params: generated parameters
            flags: static parameters

        Returns:
            the transformation
        """
        out: Tensor = self.identity_matrix(input)
        return out

    def apply_transform(
        self,
        input: Tensor,
        params: dict[str, Tensor],
        flags: dict[str, Any],
        transform: Tensor | None = None,
    ) -> Tensor:
        """Apply the transform to the input image or mask.

        Args:
            input: the input tensor image or mask
            params: generated parameters
            flags: static parameters
            transform: the geometric transformation tensor

        Returns:
            the augmented input
        """
        # Rounding the sampled 'direction' and 'dim' values turns each into a
        # binary choice.
        directions = (params['direction'].round() * 2.0 - 1.0).to(
            torch.int
        )  # convert to -1 or 1
        dims = params['dim'].round().to(torch.int) + 2  # convert to 2 or 3

        # Build the shape tensor on the same device as the index tensor;
        # torch.index_select requires both to live on the same device.
        shape = torch.tensor(input.shape, device=dims.device)
        sizes = torch.index_select(shape, dim=0, index=dims)

        # Compute every shift in one vectorised expression, then move the
        # results to the host once instead of once per sample.
        shifts = torch.round(params['beta'] * sizes * directions)
        shift_list = [int(s) for s in shifts.tolist()]
        # out[i] has shape (C, H, W), so batch-level dims 2/3 become 1/2.
        dim_list = [int(d) - 1 for d in dims.tolist()]

        # Repeat each image gamma times (B*gamma, C, H, W)
        out = input.repeat_interleave(flags['gamma'], dim=0)

        # It's necessary to roll each image individually because the shift
        # and axis vary per sample.
        for i, (shift, dim) in enumerate(zip(shift_list, dim_list, strict=True)):
            out[i] = torch.roll(out[i], shifts=shift, dims=dim)

        return out

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources