Source code for torchgeo.trainers.moco

# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.

"""MoCo trainer for self-supervised learning (SSL)."""

import os
import warnings
from collections.abc import Sequence
from typing import Any

import kornia.augmentation as K
import lightning
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from lightly.loss import NTXentLoss
from lightly.models.modules import MoCoProjectionHead
from lightly.models.utils import deactivate_requires_grad, update_momentum
from lightly.utils.scheduler import cosine_schedule
from torch import Tensor
from torch.optim import SGD, AdamW, Optimizer
from torch.optim.lr_scheduler import (
    CosineAnnealingLR,
    LinearLR,
    LRScheduler,
    MultiStepLR,
    SequentialLR,
)
from torchvision.models._api import WeightsEnum

import torchgeo.transforms as T

from ..models import get_weight
from . import utils
from .base import BaseTask


def moco_augmentations(
    version: int, size: int, weights: Tensor
) -> tuple[nn.Module, nn.Module]:
    """Data augmentations used by MoCo.

    Args:
        version: Version of MoCo.
        size: Size of patch to crop.
        weights: Weight vector for grayscale computation.

    Returns:
        Data augmentation pipelines.
    """
    # https://github.com/facebookresearch/moco/blob/main/main_moco.py#L326
    # https://github.com/facebookresearch/moco-v3/blob/main/main_moco.py#L261
    # Gaussian blur kernel size: roughly 1/10 of the patch size, forced to be odd
    ks = size // 10 // 2 * 2 + 1
    if version == 1:
        # Same as InstDisc: https://arxiv.org/abs/1805.01978
        aug1 = aug2 = K.AugmentationSequential(
            K.RandomResizedCrop(size=(size, size), scale=(0.2, 1)),
            T.RandomGrayscale(weights=weights, p=0.2),
            # Not appropriate for multispectral imagery, seasonal contrast used instead
            # K.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.4, p=1)
            K.RandomBrightness(brightness=(0.6, 1.4), p=1.0),
            K.RandomContrast(contrast=(0.6, 1.4), p=1.0),
            K.RandomHorizontalFlip(),
            K.RandomVerticalFlip(),  # added
            data_keys=['input'],
        )
    elif version == 2:
        # Similar to SimCLR: https://arxiv.org/abs/2002.05709
        aug1 = aug2 = K.AugmentationSequential(
            K.RandomResizedCrop(size=(size, size), scale=(0.2, 1)),
            # Not appropriate for multispectral imagery, seasonal contrast used instead
            # K.ColorJitter(
            #     brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1, p=0.8
            # )
            K.RandomBrightness(brightness=(0.6, 1.4), p=0.8),
            K.RandomContrast(contrast=(0.6, 1.4), p=0.8),
            T.RandomGrayscale(weights=weights, p=0.2),
            K.RandomGaussianBlur(kernel_size=(ks, ks), sigma=(0.1, 2), p=0.5),
            K.RandomHorizontalFlip(),
            K.RandomVerticalFlip(),  # added
            data_keys=['input'],
        )
    else:
        # Same as BYOL: https://arxiv.org/abs/2006.07733
        aug1 = K.AugmentationSequential(
            K.RandomResizedCrop(size=(size, size), scale=(0.08, 1)),
            # Not appropriate for multispectral imagery, seasonal contrast used instead
            # K.ColorJitter(
            #     brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.8
            # )
            K.RandomBrightness(brightness=(0.6, 1.4), p=0.8),
            K.RandomContrast(contrast=(0.6, 1.4), p=0.8),
            T.RandomGrayscale(weights=weights, p=0.2),
            K.RandomGaussianBlur(kernel_size=(ks, ks), sigma=(0.1, 2), p=1),
            K.RandomHorizontalFlip(),
            K.RandomVerticalFlip(),  # added
            data_keys=['input'],
        )
        aug2 = K.AugmentationSequential(
            K.RandomResizedCrop(size=(size, size), scale=(0.08, 1)),
            # Not appropriate for multispectral imagery, seasonal contrast used instead
            # K.ColorJitter(
            #     brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1, p=0.8
            # )
            K.RandomBrightness(brightness=(0.6, 1.4), p=0.8),
            K.RandomContrast(contrast=(0.6, 1.4), p=0.8),
            T.RandomGrayscale(weights=weights, p=0.2),
            K.RandomGaussianBlur(kernel_size=(ks, ks), sigma=(0.1, 2), p=0.1),
            K.RandomSolarize(p=0.2),
            K.RandomHorizontalFlip(),
            K.RandomVerticalFlip(),  # added
            data_keys=['input'],
        )
    return aug1, aug2


class MoCoTask(BaseTask):
    """MoCo: Momentum Contrast.

    Reference implementations:

    * https://github.com/facebookresearch/moco
    * https://github.com/facebookresearch/moco-v3

    If you use this trainer in your research, please cite the following papers:

    * v1: https://arxiv.org/abs/1911.05722
    * v2: https://arxiv.org/abs/2003.04297
    * v3: https://arxiv.org/abs/2104.02057

    .. versionadded:: 0.5
    """

    ignore = ('weights', 'augmentation1', 'augmentation2')
    monitor = 'train_loss'

    def __init__(
        self,
        model: str = 'resnet50',
        weights: WeightsEnum | str | bool | None = None,
        in_channels: int = 3,
        version: int = 3,
        layers: int = 3,
        hidden_dim: int = 4096,
        output_dim: int = 256,
        lr: float = 9.6,
        weight_decay: float = 1e-6,
        momentum: float = 0.9,
        schedule: Sequence[int] = [120, 160],
        temperature: float = 1,
        memory_bank_size: int = 0,
        moco_momentum: float = 0.99,
        gather_distributed: bool = False,
        size: int = 224,
        grayscale_weights: Tensor | None = None,
        augmentation1: nn.Module | None = None,
        augmentation2: nn.Module | None = None,
    ) -> None:
        """Initialize a new MoCoTask instance.

        Args:
            model: Name of the
                `timm <https://huggingface.co/docs/timm/reference/models>`__ model
                to use.
            weights: Initial model weights. Either a weight enum, the string
                representation of a weight enum, True for ImageNet weights, False or
                None for random weights, or the path to a saved model state dict.
            in_channels: Number of input channels to model.
            version: Version of MoCo, 1--3.
            layers: Number of layers in projection head
                (not used in v1, 2 for v2, 3 for v3).
            hidden_dim: Number of hidden dimensions in projection head
                (not used in v1, 2048 for v2, 4096 for v3).
            output_dim: Number of output dimensions in projection head
                (not used in v1, 128 for v2, 256 for v3).
            lr: Learning rate
                (0.03 x batch_size / 256 for v1/2, 0.6 x batch_size / 256 for v3).
            weight_decay: Weight decay coefficient (1e-4 for v1/2, 1e-6 for v3).
            momentum: Momentum of SGD solver (v1/2 only).
            schedule: Epochs at which to drop lr by 10x (v1/2 only).
            temperature: Temperature used in InfoNCE loss (0.07 for v1/2, 1 for v3).
            memory_bank_size: Size of memory bank (65536 for v1/2, 0 for v3).
            moco_momentum: MoCo momentum of updating key encoder
                (0.999 for v1/2, 0.99 for v3).
            gather_distributed: Gather negatives from all GPUs during distributed
                training (ignored if memory_bank_size > 0).
            size: Size of patch to crop.
            grayscale_weights: Weight vector for grayscale computation, see
                :class:`~torchgeo.transforms.RandomGrayscale`. Only used when
                ``augmentation1`` and ``augmentation2`` are None. Defaults to average
                of all bands.
            augmentation1: Data augmentation for 1st branch. Defaults to MoCo
                augmentation.
            augmentation2: Data augmentation for 2nd branch. Defaults to MoCo
                augmentation.

        Raises:
            AssertionError: If an invalid version of MoCo is requested.

        Warns:
            UserWarning: If hyperparameters do not match MoCo version requested.
        """
        # Validate hyperparameters
        assert version in range(1, 4)
        if version == 1:
            if memory_bank_size == 0:
                warnings.warn('MoCo v1 uses a memory bank')
        elif version == 2:
            if layers > 2:
                warnings.warn('MoCo v2 only uses 2 layers in its projection head')
            if memory_bank_size == 0:
                warnings.warn('MoCo v2 uses a memory bank')
        elif version == 3:
            if layers == 2:
                warnings.warn('MoCo v3 uses 3 layers in its projection head')
            if memory_bank_size > 0:
                warnings.warn('MoCo v3 does not use a memory bank')

        self.weights = weights
        super().__init__()

        grayscale_weights = grayscale_weights or torch.ones(in_channels)
        aug1, aug2 = moco_augmentations(version, size, grayscale_weights)
        self.augmentation1 = augmentation1 or aug1
        self.augmentation2 = augmentation2 or aug2

    def configure_models(self) -> None:
        """Initialize the model."""
        model: str = self.hparams['model']
        weights = self.weights
        in_channels: int = self.hparams['in_channels']
        version: int = self.hparams['version']
        layers: int = self.hparams['layers']
        hidden_dim: int = self.hparams['hidden_dim']
        output_dim: int = self.hparams['output_dim']

        # Create backbone
        self.backbone = timm.create_model(
            model, in_chans=in_channels, num_classes=0, pretrained=weights is True
        )
        self.backbone_momentum = timm.create_model(
            model, in_chans=in_channels, num_classes=0, pretrained=weights is True
        )
        deactivate_requires_grad(self.backbone_momentum)

        # Load weights
        if weights and weights is not True:
            if isinstance(weights, WeightsEnum):
                state_dict = weights.get_state_dict(progress=True)
            elif os.path.exists(weights):
                _, state_dict = utils.extract_backbone(weights)
            else:
                state_dict = get_weight(weights).get_state_dict(progress=True)
            utils.load_state_dict(self.backbone, state_dict)

        # Create projection (and prediction) head
        batch_norm = version == 3
        if version > 1:
            input_dim = self.backbone.num_features
            self.projection_head = MoCoProjectionHead(
                input_dim, hidden_dim, output_dim, layers, batch_norm=batch_norm
            )
            self.projection_head_momentum = MoCoProjectionHead(
                input_dim, hidden_dim, output_dim, layers, batch_norm=batch_norm
            )
            deactivate_requires_grad(self.projection_head_momentum)
            if version == 3:
                self.prediction_head = MoCoProjectionHead(
                    output_dim,
                    hidden_dim,
                    output_dim,
                    num_layers=2,
                    batch_norm=batch_norm,
                )

        # Initialize moving average of output
        self.avg_output_std = 0.0

    def configure_losses(self) -> None:
        """Initialize the loss criterion."""
        try:
            self.criterion = NTXentLoss(
                self.hparams['temperature'],
                (self.hparams['memory_bank_size'], self.hparams['output_dim']),
                self.hparams['gather_distributed'],
            )
        except TypeError:
            # lightly 1.4.24 and older
            self.criterion = NTXentLoss(
                self.hparams['temperature'],
                self.hparams['memory_bank_size'],
                self.hparams['gather_distributed'],
            )

    def configure_optimizers(
        self,
    ) -> 'lightning.pytorch.utilities.types.OptimizerLRScheduler':
        """Initialize the optimizer and learning rate scheduler.

        Returns:
            Optimizer and learning rate scheduler.
        """
        if self.hparams['version'] == 3:
            optimizer: Optimizer = AdamW(
                params=self.parameters(),
                lr=self.hparams['lr'],
                weight_decay=self.hparams['weight_decay'],
            )
            warmup_epochs = 40
            max_epochs = 200
            if self.trainer and self.trainer.max_epochs:
                max_epochs = self.trainer.max_epochs
            scheduler: LRScheduler = SequentialLR(
                optimizer,
                schedulers=[
                    LinearLR(
                        optimizer,
                        start_factor=1 / warmup_epochs,
                        total_iters=warmup_epochs,
                    ),
                    CosineAnnealingLR(optimizer, T_max=max_epochs),
                ],
                milestones=[warmup_epochs],
            )
        else:
            optimizer = SGD(
                params=self.parameters(),
                lr=self.hparams['lr'],
                momentum=self.hparams['momentum'],
                weight_decay=self.hparams['weight_decay'],
            )
            scheduler = MultiStepLR(
                optimizer=optimizer, milestones=self.hparams['schedule']
            )
        return {
            'optimizer': optimizer,
            'lr_scheduler': {'scheduler': scheduler, 'monitor': self.monitor},
        }

    def forward(self, x: Tensor) -> tuple[Tensor, Tensor]:
        """Forward pass of the model.

        Args:
            x: Mini-batch of images.

        Returns:
            Output of the model and backbone.
        """
        h: Tensor = self.backbone(x)
        q = h
        if self.hparams['version'] > 1:
            q = self.projection_head(q)
            if self.hparams['version'] == 3:
                q = self.prediction_head(q)
        return q, h

    def forward_momentum(self, x: Tensor) -> Tensor:
        """Forward pass of the momentum model.

        Args:
            x: Mini-batch of images.

        Returns:
            Output from the momentum model.
        """
        k: Tensor = self.backbone_momentum(x)
        if self.hparams['version'] > 1:
            k = self.projection_head_momentum(k)
        return k

    def training_step(
        self, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> Tensor:
        """Compute the training loss and additional metrics.

        Args:
            batch: The output of your DataLoader.
            batch_idx: Integer displaying index of this batch.
            dataloader_idx: Index of the current dataloader.

        Returns:
            The loss tensor.
        """
        x = batch['image']
        batch_size = x.shape[0]

        in_channels = self.hparams['in_channels']
        assert x.size(1) == in_channels or x.size(1) == 2 * in_channels
        if x.size(1) == in_channels:
            x1 = x
            x2 = x
        else:
            x1 = x[:, :in_channels]
            x2 = x[:, in_channels:]

        with torch.no_grad():
            x1 = self.augmentation1(x1)
            x2 = self.augmentation2(x2)

        m = self.hparams['moco_momentum']
        if self.hparams['version'] == 1:
            q, h1 = self.forward(x1)
            with torch.no_grad():
                update_momentum(self.backbone, self.backbone_momentum, m)
                k = self.forward_momentum(x2)
            loss: Tensor = self.criterion(q, k)
        elif self.hparams['version'] == 2:
            q, h1 = self.forward(x1)
            with torch.no_grad():
                update_momentum(self.backbone, self.backbone_momentum, m)
                update_momentum(self.projection_head, self.projection_head_momentum, m)
                k = self.forward_momentum(x2)
            loss = self.criterion(q, k)
        if self.hparams['version'] == 3:
            m = cosine_schedule(self.current_epoch, self.trainer.max_epochs, m, 1)
            q1, h1 = self.forward(x1)
            q2, h2 = self.forward(x2)
            with torch.no_grad():
                update_momentum(self.backbone, self.backbone_momentum, m)
                update_momentum(self.projection_head, self.projection_head_momentum, m)
                k1 = self.forward_momentum(x1)
                k2 = self.forward_momentum(x2)
            loss = self.criterion(q1, k2) + self.criterion(q2, k1)

        # Calculate the mean normalized standard deviation over features dimensions.
        # If this is << 1 / sqrt(h1.shape[1]), then the model is not learning anything.
        output = h1.detach()
        output = F.normalize(output, dim=1)
        output_std = torch.std(output, dim=0)
        output_std = torch.mean(output_std, dim=0)
        self.avg_output_std = 0.9 * self.avg_output_std + (1 - 0.9) * output_std.item()

        self.log('train_ssl_std', self.avg_output_std, batch_size=batch_size)
        self.log('train_loss', loss, batch_size=batch_size)

        return loss

    def validation_step(
        self, batch: Any, batch_idx: int, dataloader_idx: int = 0
    ) -> None:
        """No-op, does nothing."""

    def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        """No-op, does nothing."""

    def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        """No-op, does nothing."""
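
A minimal usage sketch follows. The data module (``SSL4EOS12DataModule``) and the hyperparameter values are illustrative assumptions drawn from the v1/2 guidance in the ``__init__`` docstring above, not a required configuration; any Lightning data module whose batches contain an ``image`` tensor with ``in_channels`` (or ``2 * in_channels``) channels can be substituted.

from lightning.pytorch import Trainer

from torchgeo.datamodules import SSL4EOS12DataModule
from torchgeo.trainers import MoCoTask

# Illustrative MoCo v2-style setup, following the v1/2 hints in the docstring
# above (2-layer projection head, temperature 0.07, 65536-slot memory bank).
task = MoCoTask(
    model='resnet50',
    in_channels=13,  # must match the number of bands yielded by the data module
    version=2,
    layers=2,
    hidden_dim=2048,
    output_dim=128,
    lr=0.03,  # 0.03 x batch_size / 256, assuming batch_size=256
    weight_decay=1e-4,
    temperature=0.07,
    memory_bank_size=65536,
    moco_momentum=0.999,
    size=224,
)

# SSL4EOS12DataModule is only one possible choice and assumes the SSL4EO-S12
# imagery is already available locally.
datamodule = SSL4EOS12DataModule(batch_size=256, num_workers=16)
trainer = Trainer(max_epochs=200)
trainer.fit(model=task, datamodule=datamodule)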
