# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.

Contribute a New DataModule

Written by: Nils Lehmann

TorchGeo provides Lightning DataModules and trainers to facilitate easy and scalable model training based on simple configuration files. Essentially, a DataModule implements the logic for reproducibly splitting a dataset into train, validation, and test sets, wrapping them in PyTorch DataLoaders, and applying augmentations to batches of data. This tutorial is a guide to adding a new datamodule to TorchGeo. It is often easiest to do so alongside a new dataset, which makes the dataset directly usable in a Lightning training and evaluation pipeline.

Adding the datamodule

Adding a datamodule to TorchGeo consists of roughly four parts:

  1. a dataset_name.py file under torchgeo/datamodules that implements the split logic and defines augmentation

  2. a dataset_name.yaml file under tests/conf that defines arguments to directly test the datamodule with the appropriate task

  3. an entry for the above yaml file in the list of files to be tested in the corresponding test_{task}.py file under tests/trainers

  4. an entry to the documentation page file datamodules.rst under docs/api/

The datamodule dataset_name.py file

Most new DataModules can inherit from one of the base classes, which take care of the bulk of the work. The dataset-specific DataModule then only needs to specify how the dataset is split into train/val/test and which augmentations are applied to batches of data.

"""NewDatasetDataModule datamodule."""

import os
from typing import Any

import kornia.augmentation as K
import torch
from torch.utils.data import Subset

# the dataset class that this datamodule wraps
from ..datasets import NewDatasetName
from .geo import NonGeoDataModule
from .utils import group_shuffle_split


# We follow the convention of appending the dataset_name with "DataModule"
class NewDatasetDataModule(NonGeoDataModule):
    """LightningDataModule implementation for the NewDataset dataset.

    Make a comment here about how the dataset is split into train/val/test.

    You can also add any other comments or references that are helpful to
    understand implementation decisions

    .. versionadded:: for example 0.7
    """
    # you can define channel-wise normalization statistics that will be applied
    # to data batches, which is usually crucial for training stability and decent performance
    mean = torch.tensor([0.5, 0.4, 0.3])
    std = torch.tensor([1.5, 1.4, 1.3])

    def __init__(
        self, batch_size: int = 64, num_workers: int = 0, size: int = 256, **kwargs: Any
    ) -> None:
        """Initialize a new NewDatasetDataModule instance.

        Args:
            batch_size: Size of each mini-batch.
            num_workers: Number of workers for parallel data loading.
            size: Target height and width; the 1000x1000 input images are
                resized to size x size.
            **kwargs: Additional keyword arguments passed to
                :class:`~torchgeo.datasets.NewDatasetName`.
        """
        # in the init method of the base class the dataset will be instantiated with **kwargs
        super().__init__(NewDatasetName, batch_size, num_workers, **kwargs)

        # you can specify a series of Kornia augmentations that will be
        # applied to a batch of training data in `on_after_batch_transfer` in the NonGeoDataModule base class
        self.train_aug = K.AugmentationSequential(
            K.Resize((size, size)),
            K.Normalize(self.mean, self.std),
            K.RandomHorizontalFlip(p=0.5),
            K.RandomVerticalFlip(p=0.5),
            data_keys=None,
            keepdim=True,
        )

        # you can also define specific augmentations for other experiment phases
        # (e.g. self.val_aug or self.test_aug); for any phase without its own
        # augmentation, self.aug will be applied
        self.aug = K.AugmentationSequential(
            K.Normalize(self.mean, self.std),
            K.Resize((size, size)),
            data_keys=None,
            keepdim=True,
        )

        self.size = size

    # setup defines how the dataset should be split
    # this could either be predefined from the dataset authors or
    # done in a prescribed way if some or no splits are specified
    def setup(self, stage: str) -> None:
        """Set up datasets.

        Args:
            stage: Either 'fit', 'validate', 'test', or 'predict'.
        """
        if stage in ['fit', 'validate']:
            dataset = NewDatasetName(split='train', **self.kwargs)
            # perhaps the dataset contains some geographical metadata based on which you would create reproducible random
            # splits
            grouping_paths = [os.path.dirname(path) for path in dataset.file_list]
            train_indices, val_indices = group_shuffle_split(
                grouping_paths, test_size=0.2, random_state=0
            )
            self.train_dataset = Subset(dataset, train_indices)
            self.val_dataset = Subset(dataset, val_indices)
        if stage in ['test']:
            self.test_dataset = NewDatasetName(split='test', **self.kwargs)
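
The grouping idea used in setup can be illustrated without TorchGeo. The following is a minimal standard-library sketch, not TorchGeo's group_shuffle_split implementation: the function name group_split_sketch and the example file paths are hypothetical. It shows the property that matters, namely that all samples sharing a group (here, a parent directory) end up on the same side of the split and that a fixed seed makes the split reproducible.

```python
import os
import random
from collections import defaultdict


def group_split_sketch(groups, test_size=0.2, seed=0):
    """Split sample indices so that no group appears on both sides.

    Illustrative stand-in only; not TorchGeo's group_shuffle_split.
    """
    # Collect the sample indices that belong to each group
    by_group = defaultdict(list)
    for index, group in enumerate(groups):
        by_group[group].append(index)

    # Shuffle the unique groups reproducibly with a fixed seed
    unique_groups = sorted(by_group)
    random.Random(seed).shuffle(unique_groups)

    # Hold out a fraction of the groups (at least one) for validation
    n_val = max(1, round(test_size * len(unique_groups)))
    val_groups = unique_groups[:n_val]
    train_groups = unique_groups[n_val:]

    train_indices = [i for g in train_groups for i in by_group[g]]
    val_indices = [i for g in val_groups for i in by_group[g]]
    return train_indices, val_indices


# Hypothetical file list; the parent directory acts as the group
paths = [
    'region_a/img_0.tif',
    'region_a/img_1.tif',
    'region_b/img_0.tif',
    'region_c/img_0.tif',
    'region_c/img_1.tif',
    'region_d/img_0.tif',
    'region_e/img_0.tif',
]
groups = [os.path.dirname(path) for path in paths]
train_indices, val_indices = group_split_sketch(groups)
```

Splitting by group rather than by individual sample avoids leaking near-duplicate imagery from the same location into both the training and validation sets.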

Linters

See the linter docs for an overview of the linters that TorchGeo employs and how to run them, for example automatically on each commit.

Unit tests

TorchGeo maintains 100% test coverage. This means that every line of code within the torchgeo directory is executed by some unit test. For new datasets, we commonly write a separate test file; datamodules, however, are tested directly with one of the task trainers. To do this, you simply need to define a YAML config file and add it to the list of files to be tested by a task. For example, if you added a new datamodule for image segmentation, you would write a config file that looks something like this:

model:
  class_path: SemanticSegmentationTask
  init_args:
    loss: 'ce'
    model: 'unet'
    backbone: 'resnet18'
    in_channels: 3 # number of input channels for the dataset
    num_classes: 7 # number of segmentation classes
    num_filters: 1 # a smaller model version for faster unit tests
    ignore_index: null # one can ignore certain classes during the loss computation
data:
  class_path: NewDatasetDataModule # the DataModule you wrote above
  init_args: # arguments to the DataModule's __init__
    batch_size: 1
  dict_kwargs:
    root: 'tests/data/deepglobelandcover' # necessary arguments for the underlying dataset class that the datamodule builds on

The yaml file should “simulate” how you would use this datamodule in an actual experiment. Save this file as dataset_name.yaml in the tests/conf directory.
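
Before running the trainer tests, it can help to confirm that every config name you listed actually has a matching file. The sketch below is a hypothetical helper, not part of TorchGeo's test suite: the names config_path, missing_configs, and config_names are illustrative, and it assumes that a name in the test list maps to tests/conf/<name>.yaml.

```python
import os

# Hypothetical list of config names exercised by one trainer test
config_names = ['deepglobelandcover', 'dataset_name']

# Assumed location of the test configs, relative to the repository root
CONF_DIR = os.path.join('tests', 'conf')


def config_path(name, conf_dir=CONF_DIR):
    """Return the path of the YAML config for a given test name."""
    return os.path.join(conf_dir, name + '.yaml')


def missing_configs(names, conf_dir=CONF_DIR):
    """Return the names whose config file does not exist on disk."""
    return [
        name for name in names if not os.path.isfile(config_path(name, conf_dir))
    ]


# Run from the repository root; an empty list means every config was found
missing = missing_configs(config_names)
```

A check like this catches a misspelled file name early, before the test run fails while loading the config.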

Final Checklist

This final checklist provides an overview of the individual parts discussed in this tutorial. You do not need to check every box before submitting a PR. If you have any questions, feel free to ask in the Slack channel, or open a PR right away so that maintainers or other community members can answer specific questions or give pointers. If you want to treat your PR as a work in progress, so that the CI tests run against your code while you tick more boxes, you can convert the PR to a draft on the right side.

  • The datamodule implementation

    • define training/val/test split

    • if there are dataset specific augmentations, implement and reference them

    • add copyright notice to top of the file

  • The config test file

    • select the appropriate task; if the dataset supports several tasks, you can create one config file per task

    • correct arguments such as the number of targets (classes)

    • add the config file to the list of files to be tested in the corresponding test_{task}.py file under tests/trainers

  • Unit Tests

    • 100% test coverage

  • Documentation

    • an entry to the documentation page file datamodules.rst under docs/api/
