Shortcuts

Source code for torchgeo.trainers.instance_segmentation

# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.

"""Trainers for instance segmentation."""

from typing import Any

import kornia.augmentation as K
import matplotlib.pyplot as plt
import torch
from matplotlib.figure import Figure
from timm.models import adapt_input_conv
from torch import Tensor
from torch.nn.parameter import Parameter
from torchmetrics import MetricCollection
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from torchvision.models import ResNet50_Weights
from torchvision.models.detection import (
    MaskRCNN_ResNet50_FPN_Weights,
    maskrcnn_resnet50_fpn,
)

from ..datasets import RGBBandsMissingError, unbind_samples
from .base import BaseTask
from .utils import GeneralizedRCNNTransformNoOp


[docs]class InstanceSegmentationTask(BaseTask): """Instance Segmentation. .. versionadded:: 0.7 """ ignore = None monitor = 'val_segm_map' mode = 'max'
[docs] def __init__( self, model: str = 'mask-rcnn', backbone: str = 'resnet50', weights: bool | None = None, in_channels: int = 3, num_classes: int = 91, lr: float = 1e-3, patience: int = 10, freeze_backbone: bool = False, ) -> None: """Initialize a new InstanceSegmentationTask instance. Note that we disable the internal normalize+resize transform of the MaskRCNN model. Please ensure your images are appropriately resized before passing them to the model. Args: model: Name of the model to use. backbone: Name of the backbone to use. weights: Initial model weights. True for ImageNet weights, False or None for random weights. in_channels: Number of input channels to model. num_classes: Number of prediction classes (including the background). lr: Learning rate for optimizer. patience: Patience for learning rate scheduler. freeze_backbone: Freeze the backbone network to fine-tune the decoder and segmentation head. """ super().__init__()
[docs] def configure_models(self) -> None: """Initialize the model. Raises: ValueError: If *model* or *backbone* are invalid. """ model: str = self.hparams['model'] backbone: str = self.hparams['backbone'] in_channels: int = self.hparams['in_channels'] num_classes: int = self.hparams['num_classes'] weights = None weights_backbone = None if self.hparams['weights']: weights_backbone = ResNet50_Weights.IMAGENET1K_V1 # TODO: drop last layer of weights if num_classes == 91: weights = MaskRCNN_ResNet50_FPN_Weights.COCO_V1 # Create model if model == 'mask-rcnn': if backbone == 'resnet50': self.model = maskrcnn_resnet50_fpn( weights=weights, num_classes=num_classes, weights_backbone=weights_backbone, ) self.model.transform = GeneralizedRCNNTransformNoOp() else: msg = f"Invalid backbone type '{backbone}'. Supported backbone: 'resnet50'" raise ValueError(msg) else: msg = f"Invalid model type '{model}'. Supported model: 'mask-rcnn'" raise ValueError(msg) weight = adapt_input_conv(in_channels, self.model.backbone.body.conv1.weight) self.model.backbone.body.conv1.weight = Parameter(weight) self.model.backbone.body.conv1.in_channels = in_channels # Freeze backbone if self.hparams['freeze_backbone']: for param in self.model.backbone.parameters(): param.requires_grad = False
[docs] def configure_metrics(self) -> None: """Initialize the performance metrics. * :class:`~torchmetrics.detection.mean_ap.MeanAveragePrecision`: Mean average precision (mAP) and mean average recall (mAR). Precision is the number of true positives divided by the number of true positives + false positives. Recall is the number of true positives divived by the number of true positives + false negatives. Uses 'macro' averaging. Higher values are better. .. note:: * 'Micro' averaging suits overall performance evaluation but may not reflect minority class accuracy. * 'Macro' averaging gives equal weight to each class, and is useful for balanced performance assessment across imbalanced classes. """ metrics = MetricCollection([MeanAveragePrecision(iou_type=('bbox', 'segm'))]) self.val_metrics = metrics.clone(prefix='val_') self.test_metrics = metrics.clone(prefix='test_')
[docs] def training_step( self, batch: Any, batch_idx: int, dataloader_idx: int = 0 ) -> Tensor: """Compute the training loss. Args: batch: The output of your DataLoader. batch_idx: Integer displaying index of this batch. dataloader_idx: Index of the current dataloader. Returns: The loss tensor. """ x = batch['image'] y = { 'boxes': batch['bbox_xyxy'], 'labels': batch['label'], 'masks': batch['mask'], } loss_dict = self(x.unbind(), unbind_samples(y)) self.log_dict(loss_dict, batch_size=len(x)) loss: Tensor = sum(loss_dict.values()) return loss
[docs] def validation_step( self, batch: Any, batch_idx: int, dataloader_idx: int = 0 ) -> None: """Compute the validation metrics. Args: batch: The output of your DataLoader. batch_idx: Integer displaying index of this batch. dataloader_idx: Index of the current dataloader. """ x = batch['image'] y = { 'boxes': batch['bbox_xyxy'], 'labels': batch['label'], 'masks': batch['mask'], } y_hat = self(x.unbind()) for pred in y_hat: pred['masks'] = (pred['masks'] > 0.5).squeeze(1).to(torch.uint8) metrics = self.val_metrics(y_hat, unbind_samples(y)) # https://github.com/Lightning-AI/torchmetrics/pull/1832#issuecomment-1623890714 metrics.pop('val_classes', None) self.log_dict(metrics, batch_size=len(x)) if ( batch_idx < 10 and hasattr(self.trainer, 'datamodule') and hasattr(self.trainer.datamodule, 'plot') and self.logger and hasattr(self.logger, 'experiment') and hasattr(self.logger.experiment, 'add_figure') ): datamodule = self.trainer.datamodule aug = K.AugmentationSequential( K.Denormalize(datamodule.mean, datamodule.std), data_keys=None, keepdim=True, ) batch = aug(batch) batch['prediction_bbox_xyxy'] = [pred['boxes'].cpu() for pred in y_hat] batch['prediction_mask'] = [pred['masks'].cpu() for pred in y_hat] batch['prediction_label'] = [pred['labels'].cpu() for pred in y_hat] batch['prediction_score'] = [pred['scores'].cpu() for pred in y_hat] batch['image'] = batch['image'].cpu() sample = unbind_samples(batch)[0] fig: Figure | None = None try: fig = datamodule.plot(sample) except RGBBandsMissingError: pass if fig: summary_writer = self.logger.experiment summary_writer.add_figure( f'image/{batch_idx}', fig, global_step=self.global_step ) plt.close()
[docs] def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None: """Compute the test metrics. Args: batch: The output of your DataLoader. batch_idx: Integer displaying index of this batch. dataloader_idx: Index of the current dataloader. """ x = batch['image'] y = { 'boxes': batch['bbox_xyxy'], 'labels': batch['label'], 'masks': batch['mask'], } y_hat = self(x.unbind()) for pred in y_hat: pred['masks'] = (pred['masks'] > 0.5).squeeze(1).to(torch.uint8) metrics = self.test_metrics(y_hat, unbind_samples(y)) # https://github.com/Lightning-AI/torchmetrics/pull/1832#issuecomment-1623890714 metrics.pop('test_classes', None) self.log_dict(metrics, batch_size=len(x))
[docs] def predict_step( self, batch: Any, batch_idx: int, dataloader_idx: int = 0 ) -> list[dict[str, Tensor]]: """Compute the predicted masks. Args: batch: The output of your DataLoader. batch_idx: Integer displaying index of this batch. dataloader_idx: Index of the current dataloader. Returns: Output predicted masks. """ x = batch['image'] y_hat: list[dict[str, Tensor]] = self(x.unbind()) for pred in y_hat: keep = pred['scores'] > 0.05 pred['boxes'] = pred['boxes'][keep] pred['labels'] = pred['labels'][keep] pred['scores'] = pred['scores'][keep] pred['masks'] = (pred['masks'] > 0.5).squeeze(1).to(torch.uint8)[keep] return y_hat

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources