# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.
# Contribute a New DataModule
Written by: Nils Lehmann
TorchGeo provides Lightning DataModules and trainers to facilitate easy and scalable model training based on simple configuration files. Essentially, a DataModule implements the logic for splitting a dataset into train, validation, and test splits for reproducibility, wrapping them in PyTorch DataLoaders, and applying augmentations to batches of data. This tutorial outlines how to add a new datamodule to TorchGeo. It is often easiest to do so alongside a new dataset, and it makes the dataset directly usable in a Lightning training and evaluation pipeline.
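To see what that pipeline looks like from a user's perspective, here is a minimal sketch that trains an existing datamodule, `EuroSAT100DataModule`, with a classification trainer. The `root` path and hyperparameters are illustrative:

```python
from lightning.pytorch import Trainer

from torchgeo.datamodules import EuroSAT100DataModule
from torchgeo.trainers import ClassificationTask

# The datamodule handles splits, dataloaders, and batch augmentations;
# extra kwargs such as `root` and `download` are forwarded to the dataset.
datamodule = EuroSAT100DataModule(batch_size=8, num_workers=2, root='data', download=True)
task = ClassificationTask(model='resnet18', in_channels=13, num_classes=10)
trainer = Trainer(max_epochs=1, fast_dev_run=True)
trainer.fit(model=task, datamodule=datamodule)
```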
## Adding the datamodule
Adding a datamodule to TorchGeo consists of roughly four parts:

1. a `dataset_name.py` file under `torchgeo/datamodules` that implements the split logic and defines augmentations
2. a `dataset_name.yaml` file under `tests/conf` that defines arguments to directly test the datamodule with the appropriate task
3. adding the above yaml file to the list of files to be tested in the corresponding `test_{task}.py` file under `tests/trainers`
4. an entry in the documentation page file `datamodules.rst` under `docs/api/`
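Concretely, using a hypothetical dataset called `NewDataset`, the files you touch would be laid out roughly like this:

```text
torchgeo/datamodules/new_dataset.py    # the datamodule implementation
tests/conf/new_dataset.yaml            # config used by the trainer tests
tests/trainers/test_segmentation.py    # add the config name to the tested list
docs/api/datamodules.rst               # add an autoclass entry
```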
## The datamodule `dataset_name.py` file
The vast majority of new DataModules can inherit from one of the base classes, which take care of most of the work. The goal of the dataset-specific DataModule is to specify how the dataset should be split into train/val/test and which augmentations should be applied to batches of data.
"""NewDatasetDataModule datamodule."""
import os
from typing import Any
import kornia.augmentation as K
import torch
from torch.utils.data import Subset
from .geo import NonGeoDataModule
from .utils import group_shuffle_split
# We follow the convention of appending the dataset_name with "DataModule"
class NewDatasetDataModule(NonGeoDataModule):
"""LightningDataModule implementation for the NewDataset dataset.
Make a comment here about how the dataset is split into train/val/test.
You can also add any other comments or references that are helpful to
understand implementation decisions
.. versionadded:: for example 0.7
"""
# you can define channelwise normalization statistics that will be applied
# to data batches, which is usually crucial for training stability and decent performance
mean = torch.Tensor([0.5, 0.4, 0.3])
std = torch.Tensor([1.5, 1.4, 1.3])
def __init__(
self, batch_size: int = 64, num_workers: int = 0, size: int = 256, **kwargs: Any
) -> None:
"""Initialize a new NewDatasetModule instance.
Args:
batch_size: Size of each mini-batch.
num_workers: Number of workers for parallel data loading.
size: resize images of input size 1000x1000 to size x size
**kwargs: Additional keyword arguments passed to
:class:`~torchgeo.datasets.NewDataset`.
"""
# in the init method of the base class the dataset will be instantiated with **kwargs
super().__init__(NewDatasetName, batch_size, num_workers, **kwargs)
# you can specify a series of Kornia augmentations that will be
# applied to a batch of training data in `on_after_batch_transfer` in the NonGeoDataModule base class
self.train_aug = K.AugmentationSequential(
K.Resize((size, size)),
K.Normalize(self.mean, self.std),
K.RandomHorizontalFlip(p=0.5),
K.RandomVerticalFlip(p=0.5),
data_keys=None,
keepdim=True,
)
# you can also define specific augmentations for other experiment phases, if not specified
# self.aug Augmentations will be applied
self.aug = K.AugmentationSequential(
K.Normalize(self.mean, self.std),
K.Resize((size, size)), data_keys=None, keepdim=True
)
self.size = size
# setup defines how the dataset should be split
# this could either be predefined from the dataset authors or
# done in a prescribed way if some or no splits are specified
def setup(self, stage: str) -> None:
"""Set up datasets.
Args:
stage: Either 'fit', 'validate', 'test', or 'predict'.
"""
if stage in ['fit', 'validate']:
dataset = NewDatasetName(split='train', **self.kwargs)
# perhaps the dataset contains some geographical metadata based on which you would create reproducible random
# splits
grouping_paths = [os.path.dirname(path) for path in dataset.file_list]
train_indices, val_indices = group_shuffle_split(
grouping_paths, test_size=0.2, random_state=0
)
self.train_dataset = Subset(dataset, train_indices)
self.val_dataset = Subset(dataset, val_indices)
if stage in ['test']:
self.test_dataset = NewDatasetName(split='test', **self.kwargs)
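As a quick sanity check during development, you can instantiate the datamodule and pull a batch manually. This is a sketch with hypothetical arguments (`root` and the `'image'`/`'mask'` sample keys depend on your dataset implementation):

```python
# Smoke test for the hypothetical datamodule above; `root` and the sample
# keys ('image', 'mask') depend on the actual dataset implementation.
dm = NewDatasetDataModule(batch_size=32, num_workers=4, size=256, root='data/new_dataset')
dm.setup('fit')
batch = next(iter(dm.train_dataloader()))
print(batch['image'].shape, batch['mask'].shape)
# Note: the Kornia augmentations (resize, normalize, flips) are applied by
# Lightning in `on_after_batch_transfer`, not by the dataloader itself.
```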
## Linters
See the linter docs for an overview of the linters that TorchGeo employs and how to apply them, for example automatically during commits.
## Unit tests
TorchGeo maintains a test coverage of 100%. This means that every line of code within the torchgeo directory is called by some unit test. For new datasets we commonly write a separate test file; datamodules, however, we prefer to test directly with one of the task trainers. To do this, you simply need to define a `config.yaml` file and add it to the list of files to be tested by a task. For example, if you added a new datamodule for image segmentation, you would write a config file that looks something like this:
```yaml
model:
  class_path: SemanticSegmentationTask
  init_args:
    loss: 'ce'
    model: 'unet'
    backbone: 'resnet18'
    in_channels: 3 # number of input channels of the dataset
    num_classes: 7 # number of segmentation classes
    num_filters: 1 # a smaller model version for faster unit tests
    ignore_index: null # one can ignore certain classes during the loss computation
data:
  class_path: NewDatasetDataModule # the DataModule you wrote above
  init_args:
    batch_size: 1
  dict_kwargs:
    root: 'tests/data/deepglobelandcover' # necessary arguments for the underlying dataset class that the datamodule builds on
```
The yaml file should "simulate" how you would use this datamodule for an actual experiment. Add this file as `dataset_name.yaml` to the `tests/conf` directory.
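The registration step is then a one-line change in the corresponding trainer test. The exact structure of `tests/trainers/test_segmentation.py` may differ, but it looks roughly like this sketch, where the string must match the config file name:

```python
# Sketch of tests/trainers/test_segmentation.py (structure abbreviated):
# add the new config name to the parametrized list of tested configs.
@pytest.mark.parametrize(
    'name',
    [
        'deepglobelandcover',
        'new_dataset_name',  # matches tests/conf/new_dataset_name.yaml
        # ...
    ],
)
def test_trainer(self, name: str, fast_dev_run: bool) -> None: ...
```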
## Final Checklist
This final checklist provides an overview of the individual parts discussed in this tutorial. You definitely do not need to check all boxes before submitting a PR. If you have any questions, feel free to ask in the Slack channel, or open a PR early so that maintainers or other community members can answer specific questions and give pointers. If you want to run your PR as a work in progress, so that the CI tests run against your code while you keep ticking more boxes, you can also convert the PR to a draft on the right-hand side.
- [ ] The datamodule implementation
  - [ ] define the training/val/test splits
  - [ ] if there are dataset-specific augmentations, implement and reference them
  - [ ] add the copyright notice to the top of the file
- [ ] The config test file
  - [ ] select the appropriate task; if the dataset supports multiple tasks, you can create one config per task
  - [ ] correct arguments, such as the number of targets (classes)
  - [ ] add the config file to the list of files to be tested in the corresponding `test_{task}.py` file under `tests/trainers`
- [ ] Unit tests
  - [ ] 100% test coverage
- [ ] Documentation
  - [ ] an entry in the documentation page file `datamodules.rst` under `docs/api/` (see the sketch below)
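For orientation, an entry in `datamodules.rst` typically consists of a small section header plus an autoclass directive; the exact surrounding formatting should follow the existing entries in that file:

```rst
NewDataset
^^^^^^^^^^

.. autoclass:: NewDatasetDataModule
```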