Shortcuts

Source code for torchgeo.trainers.regression

# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.

"""Trainers for regression."""

import os
from typing import Any

import kornia.augmentation as K
import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp
import timm
import torch
import torch.nn as nn
from matplotlib.figure import Figure
from torch import Tensor
from torchmetrics import MeanAbsoluteError, MeanSquaredError, MetricCollection
from torchvision.models._api import WeightsEnum

from ..datasets import RGBBandsMissingError, unbind_samples
from ..models import FCN, get_weight
from . import utils
from .base import BaseTask


[docs]class RegressionTask(BaseTask): """Regression.""" target_key = 'label'
[docs] def __init__( self, model: str = 'resnet50', backbone: str = 'resnet50', weights: WeightsEnum | str | bool | None = None, in_channels: int = 3, num_outputs: int = 1, num_filters: int = 3, loss: str = 'mse', lr: float = 1e-3, patience: int = 10, freeze_backbone: bool = False, freeze_decoder: bool = False, ) -> None: """Initialize a new RegressionTask instance. Args: model: Name of the `timm <https://huggingface.co/docs/timm/reference/models>`__ or `smp <https://smp.readthedocs.io/en/latest/models.html>`__ model to use. backbone: Name of the `timm <https://smp.readthedocs.io/en/latest/encoders_timm.html>`__ or `smp <https://smp.readthedocs.io/en/latest/encoders.html>`__ backbone to use. Only applicable to PixelwiseRegressionTask. weights: Initial model weights. Either a weight enum, the string representation of a weight enum, True for ImageNet weights, False or None for random weights, or the path to a saved model state dict. in_channels: Number of input channels to model. num_outputs: Number of prediction outputs. num_filters: Number of filters. Only applicable when model='fcn'. loss: One of 'mse' or 'mae'. lr: Learning rate for optimizer. patience: Patience for learning rate scheduler. freeze_backbone: Freeze the backbone network to linear probe the regression head. Does not support FCN models. freeze_decoder: Freeze the decoder network to linear probe the regression head. Does not support FCN models. Only applicable to PixelwiseRegressionTask. .. versionchanged:: 0.4 Change regression model support from torchvision.models to timm .. versionadded:: 0.5 The *freeze_backbone* and *freeze_decoder* parameters. .. versionchanged:: 0.5 *learning_rate* and *learning_rate_schedule_patience* were renamed to *lr* and *patience*. """ self.weights = weights super().__init__()
[docs] def configure_models(self) -> None: """Initialize the model.""" # Create model weights = self.weights self.model = timm.create_model( self.hparams['model'], num_classes=self.hparams['num_outputs'], in_chans=self.hparams['in_channels'], pretrained=weights is True, ) # Load weights if weights and weights is not True: if isinstance(weights, WeightsEnum): state_dict = weights.get_state_dict(progress=True) elif os.path.exists(weights): _, state_dict = utils.extract_backbone(weights) else: state_dict = get_weight(weights).get_state_dict(progress=True) utils.load_state_dict(self.model, state_dict) # Freeze backbone and unfreeze classifier head if self.hparams['freeze_backbone']: for param in self.model.parameters(): param.requires_grad = False for param in self.model.get_classifier().parameters(): param.requires_grad = True
[docs] def configure_losses(self) -> None: """Initialize the loss criterion. Raises: ValueError: If *loss* is invalid. """ loss: str = self.hparams['loss'] if loss == 'mse': self.criterion: nn.Module = nn.MSELoss() elif loss == 'mae': self.criterion = nn.L1Loss() else: raise ValueError( f"Loss type '{loss}' is not valid. " "Currently, supports 'mse' or 'mae' loss." )
[docs] def configure_metrics(self) -> None: """Initialize the performance metrics. * :class:`~torchmetrics.MeanSquaredError`: The average of the squared differences between the predicted and actual values (MSE) and its square root (RMSE). Lower values are better. * :class:`~torchmetrics.MeanAbsoluteError`: The average of the absolute differences between the predicted and actual values (MAE). Lower values are better. """ metrics = MetricCollection( { 'RMSE': MeanSquaredError(squared=False), 'MSE': MeanSquaredError(squared=True), 'MAE': MeanAbsoluteError(), } ) self.train_metrics = metrics.clone(prefix='train_') self.val_metrics = metrics.clone(prefix='val_') self.test_metrics = metrics.clone(prefix='test_')
[docs] def training_step( self, batch: Any, batch_idx: int, dataloader_idx: int = 0 ) -> Tensor: """Compute the training loss and additional metrics. Args: batch: The output of your DataLoader. batch_idx: Integer displaying index of this batch. dataloader_idx: Index of the current dataloader. Returns: The loss tensor. """ x = batch['image'] batch_size = x.shape[0] # TODO: remove .to(...) once we have a real pixelwise regression dataset y = batch[self.target_key].to(torch.float) y_hat = self(x) if y_hat.ndim != y.ndim: y = y.unsqueeze(dim=1) loss: Tensor = self.criterion(y_hat, y) self.log('train_loss', loss, batch_size=batch_size) self.train_metrics(y_hat, y) self.log_dict(self.train_metrics, batch_size=batch_size) return loss
[docs] def validation_step( self, batch: Any, batch_idx: int, dataloader_idx: int = 0 ) -> None: """Compute the validation loss and additional metrics. Args: batch: The output of your DataLoader. batch_idx: Integer displaying index of this batch. dataloader_idx: Index of the current dataloader. """ x = batch['image'] batch_size = x.shape[0] # TODO: remove .to(...) once we have a real pixelwise regression dataset y = batch[self.target_key].to(torch.float) y_hat = self(x) if y_hat.ndim != y.ndim: y = y.unsqueeze(dim=1) loss = self.criterion(y_hat, y) self.log('val_loss', loss, batch_size=batch_size) self.val_metrics(y_hat, y) self.log_dict(self.val_metrics, batch_size=batch_size) if ( batch_idx < 10 and hasattr(self.trainer, 'datamodule') and hasattr(self.trainer.datamodule, 'plot') and self.logger and hasattr(self.logger, 'experiment') and hasattr(self.logger.experiment, 'add_figure') ): datamodule = self.trainer.datamodule aug = K.AugmentationSequential( K.Denormalize(datamodule.mean, datamodule.std), data_keys=None, keepdim=True, ) batch = aug(batch) if self.target_key == 'mask': y = y.squeeze(dim=1) y_hat = y_hat.squeeze(dim=1) batch['prediction'] = y_hat for key in ['image', self.target_key, 'prediction']: batch[key] = batch[key].cpu() sample = unbind_samples(batch)[0] fig: Figure | None = None try: fig = datamodule.plot(sample) except RGBBandsMissingError: pass if fig: summary_writer = self.logger.experiment summary_writer.add_figure( f'image/{batch_idx}', fig, global_step=self.global_step ) plt.close()
[docs] def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: """Compute the test loss and additional metrics. Args: batch: The output of your DataLoader. batch_idx: Integer displaying index of this batch. dataloader_idx: Index of the current dataloader. """ x = batch['image'] batch_size = x.shape[0] # TODO: remove .to(...) once we have a real pixelwise regression dataset y = batch[self.target_key].to(torch.float) y_hat = self(x) if y_hat.ndim != y.ndim: y = y.unsqueeze(dim=1) loss = self.criterion(y_hat, y) self.log('test_loss', loss, batch_size=batch_size) self.test_metrics(y_hat, y) self.log_dict(self.test_metrics, batch_size=batch_size)
[docs] def predict_step( self, batch: Any, batch_idx: int, dataloader_idx: int = 0 ) -> Tensor: """Compute the predicted regression values. Args: batch: The output of your DataLoader. batch_idx: Integer displaying index of this batch. dataloader_idx: Index of the current dataloader. Returns: Output predicted probabilities. """ x = batch['image'] y_hat: Tensor = self(x) return y_hat
[docs]class PixelwiseRegressionTask(RegressionTask): """LightningModule for pixelwise regression of images. .. versionadded:: 0.5 """ target_key = 'mask'
[docs] def configure_models(self) -> None: """Initialize the model.""" weights = self.weights model = self.hparams['model'] backbone = self.hparams['backbone'] in_channels = self.hparams['in_channels'] match model: case 'unet': self.model = smp.Unet( encoder_name=backbone, encoder_weights='imagenet' if weights is True else None, in_channels=in_channels, classes=1, ) case 'deeplabv3+': self.model = smp.DeepLabV3Plus( encoder_name=backbone, encoder_weights='imagenet' if weights is True else None, in_channels=in_channels, classes=1, ) case 'fcn': self.model = FCN( in_channels=in_channels, classes=1, num_filters=self.hparams['num_filters'], ) case 'upernet': self.model = smp.UPerNet( encoder_name=backbone, encoder_weights='imagenet' if weights is True else None, in_channels=in_channels, classes=1, ) case 'segformer': self.model = smp.Segformer( encoder_name=backbone, encoder_weights='imagenet' if weights is True else None, in_channels=in_channels, classes=1, ) case 'dpt': self.model = smp.DPT( encoder_name=backbone, encoder_weights='imagenet' if weights is True else None, in_channels=in_channels, classes=1, ) if model != 'fcn': if weights and weights is not True: if isinstance(weights, WeightsEnum): state_dict = weights.get_state_dict(progress=True) elif os.path.exists(weights): _, state_dict = utils.extract_backbone(weights) else: state_dict = get_weight(weights).get_state_dict(progress=True) self.model.encoder.load_state_dict(state_dict) # Freeze backbone if self.hparams.get('freeze_backbone', False) and model != 'fcn': for param in self.model.encoder.parameters(): param.requires_grad = False # Freeze decoder if self.hparams.get('freeze_decoder', False) and model != 'fcn': for param in self.model.decoder.parameters(): param.requires_grad = False

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources