# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.

Contribute a New Non-Geospatial Dataset

Written by: Nils Lehmann

Open-source datasets have significantly accelerated machine learning research. Geospatial machine learning datasets can be particularly complex to work with compared to more standard RGB-based vision datasets. To spare the community from implementing the same data loading logic over and over, TorchGeo provides dozens of built-in datasets that can be downloaded and made ready for use in PyTorch with a single line of code. This tutorial shows how you can add a new non-geospatial dataset to this growing collection.

As a reminder, TorchGeo differentiates between two types of datasets: geospatial and non-geospatial datasets. Non-geospatial datasets are integer indexed, like the datasets one might be familiar with from torchvision, while geospatial datasets are indexed via spatiotemporal bounding boxes. Non-geospatial datasets can still return geospatial and other metadata and should be specific to the remote sensing domain.
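The toy sketch below illustrates the difference in how samples are addressed, using plain Python. Everything in it is hypothetical and only mimics the idea: the Box class, the file names and the query function. TorchGeo's own geospatial datasets use a proper spatial and temporal index and return data rather than file names.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Box:
    """Toy 2D bounding box (minx, maxx, miny, maxy); not TorchGeo's BoundingBox."""

    minx: float
    maxx: float
    miny: float
    maxy: float

    def intersects(self, other: 'Box') -> bool:
        # Closed intervals overlap on both axes
        return (
            self.minx <= other.maxx
            and other.minx <= self.maxx
            and self.miny <= other.maxy
            and other.miny <= self.maxy
        )


# Non-geospatial style: samples are addressed by position 0..len-1
chips = ['chip_a.tif', 'chip_b.tif', 'chip_c.tif']
print(chips[1])  # chip_b.tif

# Geospatial style: files are found by the region they cover
footprints = {
    'tile_west.tif': Box(0, 10, 0, 10),
    'tile_east.tif': Box(10, 20, 0, 10),
}


def query(box: Box) -> list[str]:
    """Return every file whose footprint overlaps the query box (linear scan)."""
    return [name for name, fp in footprints.items() if fp.intersects(box)]


print(query(Box(8, 12, 2, 4)))  # ['tile_west.tif', 'tile_east.tif']
```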

Setup

First, we install TorchGeo and its dependencies.

%pip install torchgeo

Where to start

There are many types of remote sensing datasets. Satellite-Image-Deep-Learning maintains a list of many of these datasets, as well as links to other similar curated lists.

Two factors make a dataset much easier to add. The first is whether it can be downloaded easily. The second is whether it comes with a GitHub repository and publication that outline how the authors intend it to be used. Neither is a strict requirement. Sometimes it is even more worthwhile to add a dataset without an existing code base, because the marginal contribution to the community is greater: users of the dataset no longer have to write the loading implementation from scratch.

Adding the dataset

Once you have identified a dataset that you would like to add to TorchGeo, determine roughly which application category it falls into. For example, it might be a segmentation dataset based on a collection of .png files, or a classification dataset based on pre-defined image chips in .tif files. In the latter case, if the .tif files have very large dimensions, such that loading a single file is costly, consider adding the dataset as a geospatial dataset for easier indexing. Once you have identified the task (such as segmentation vs. classification) and the dataset format, check whether a dataset of the same or a similar category already exists in TorchGeo. All datasets inherit from either the NonGeoDataset or the GeoDataset base class. These base classes provide an outline for the implementation logic as well as additional utility functions that should be reused. This reduces code duplication and makes it easier to unit test datasets.
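One rough way to judge whether loading a single file is costly is to estimate its decoded, in-memory size. You need its height, width, number of bands and bytes per value. The sketch below does that arithmetic. The 256 MiB cut-off and both function names are hypothetical choices for illustration, not a TorchGeo rule. The final decision also depends on whether the data already comes as pre-cut chips.

```python
def in_memory_bytes(height: int, width: int, bands: int, bytes_per_value: int) -> int:
    """Uncompressed size of one image once decoded into an array."""
    return height * width * bands * bytes_per_value


# Hypothetical cut-off, for illustration only (not a TorchGeo rule)
LARGE_FILE_BYTES = 256 * 1024**2  # 256 MiB


def suggest_base_class(height: int, width: int, bands: int, bytes_per_value: int) -> str:
    """Suggest a base class based only on the decoded size of a single file."""
    size = in_memory_bytes(height, width, bands, bytes_per_value)
    return 'GeoDataset' if size > LARGE_FILE_BYTES else 'NonGeoDataset'


# A 256 x 256 RGB uint8 chip: 196,608 bytes
print(suggest_base_class(256, 256, 3, 1))  # NonGeoDataset
# A 10,000 x 10,000 scene with 4 uint16 bands: 800,000,000 bytes
print(suggest_base_class(10_000, 10_000, 4, 2))  # GeoDataset
```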

Adding a dataset to TorchGeo consists of roughly four steps:

  1. a dataset_name.py file itself that implements the logic of the dataset

  2. a data.py file that creates dummy data in the same structure and format as the original dataset for unit tests

  3. a test_dataset_name.py file that implements unit tests for the dataset

  4. entries in the documentation files non_geo_datasets.csv and datasets.rst
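These steps map onto a handful of paths in a TorchGeo checkout. The sketch below lists them for the placeholder module name my_new_dataset. Most locations come from this tutorial. The dataset module itself is assumed to sit next to torchgeo/datasets/__init__.py. The helper names are hypothetical.

```python
from pathlib import Path


def contribution_files(name: str) -> dict[str, list[Path]]:
    """Files a new non-geospatial dataset contribution usually touches.

    name: snake_case module name of the new dataset (placeholder, e.g.
    'my_new_dataset'). Paths are relative to the root of a TorchGeo checkout.
    """
    return {
        # Files you create
        'new': [
            Path('torchgeo', 'datasets', f'{name}.py'),
            Path('tests', 'data', name, 'data.py'),
            Path('tests', 'datasets', f'test_{name}.py'),
        ],
        # Existing files you add an entry to
        'edit': [
            Path('torchgeo', 'datasets', '__init__.py'),
            Path('docs', 'api', 'datasets.rst'),
            Path('docs', 'api', 'datasets', 'non_geo_datasets.csv'),
        ],
    }


def missing_new_files(repo_root: Path, name: str) -> list[Path]:
    """New files from contribution_files() that do not exist yet under repo_root."""
    new_files = contribution_files(name)['new']
    return [p for p in new_files if not (repo_root / p).exists()]


for kind, paths in contribution_files('my_new_dataset').items():
    for path in paths:
        print(f'{kind:>4}: {path}')
```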

The dataset_name.py file

This file implements the logic to load a sample from the dataset as well as downloading the dataset automatically if possible. The new dataset inherits from a base class and the documentation string (docstring) of the class should contain:

  • a short summary of the dataset

  • an outline of the features, such as the task the dataset is designed to solve

  • an outline of the format the dataset comes in, e.g., file types, pixel dimensions, etc.

  • a proper reference to the dataset such as a link to the paper so users can adequately cite the dataset when using it

  • if required, a note about additional dependencies that are not part of TorchGeo’s required dependencies

The dataset implementation itself should contain:

  • a method to create an index structure the dataset can iterate over to load samples. This index structure also defines the length (__len__) of the dataset, i.e. how many individual samples can be loaded from the dataset

  • a __getitem__ method that takes an integer index argument, loads a sample of the dataset, and returns its components in a dictionary

  • a _verify method that checks whether the dataset is already on the filesystem or has been downloaded and only needs to be extracted, and otherwise downloads and extracts it from the web

  • a plot method that can visually display a single sample of the dataset

The code below attempts to roughly outline the parts required for a new NonGeoDataset. Specifics are of course very dependent on the type of dataset you want to add, but this template and other existing datasets should give you a decent starting point.

from collections.abc import Callable

from matplotlib.figure import Figure
from torch import Tensor

from torchgeo.datasets import NonGeoDataset
from torchgeo.datasets.utils import Path


class MyNewDataset(NonGeoDataset):
    """MyNewDataset.

    Short summary of the dataset and link to its homepage.

    Dataset features:

    * number of classes
    * sensors
    * area covered
    * etc.

    Dataset format:

    * what file format and shape the input data comes in
    * what file format and shape the target data comes in
    * possible metadata files

    If you use this dataset in your research, please cite the following paper:

    * URL of publication or citation information

    .. versionadded:: next TorchGeo minor release version, e.g., 1.0
    """

    # In this part of the code you can define class attributes such as a list of
    # class names, color maps, url and checksums for data download, and other
    # attributes that one might require repeatedly in the subsequent class methods.

    def __init__(
        self,
        root: Path = 'data',
        split: str = 'train',
        transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
        download: bool = False,
    ) -> None:
        """Initialize the dataset.

        The init parameters can include additional arguments, such as an option to
        select specific image bands, data modalities, or other arguments that give
        greater control over data loading. They should all have reasonable defaults.

        Args:
            root: root directory where dataset can be found
            split: one of "train", "val", or "test"
            transforms: a function/transform that takes input sample and its target as
                entry and returns a transformed version
            download: if True, download dataset and store it in the root directory
        """

    def __len__(self) -> int:
        """The length of the dataset.

        This is the total number of samples per epoch, and is used to define the
        maximum allowed index that can be passed to `__getitem__`.
        """

    def __getitem__(self, index: int) -> dict[str, Tensor]:
        """A single sample from the dataset.

        Load a single input image and target label or mask, and return it in a
        dictionary.
        """

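    def _verify(self) -> None:
        """Verify the integrity of the dataset.

        Check whether the extracted files are already in ``root``; if not, check for
        the downloaded archive and extract it; otherwise download and extract the
        dataset if ``download`` is True, or raise a DatasetNotFoundError.
        """

    def plot(
        self,
        sample: dict[str, Tensor],
        show_titles: bool = True,
        suptitle: str | None = None,
    ) -> Figure:
        """Plot a sample of the dataset for visualization purposes.

        This might involve selecting the RGB bands, using a colormap to display a mask,
        adding a legend with class labels, etc.

        Args:
            sample: a sample returned by :meth:`__getitem__`
            show_titles: flag indicating whether to show titles above each panel
            suptitle: optional string to use as a suptitle

        Returns:
            a matplotlib Figure with the rendered sample
        """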

The data.py file

The data.py file is placed in the tests/data/dataset_name/ directory. It creates a small dummy dataset that replicates the features and format of the full dataset for unit tests. This keeps the tests fast, since there is neither the time nor the storage space to download the real dataset during testing. It also complies with the dataset license, which may not permit redistributing the real data.

The script should:

  • replicate the directory structure and file names

  • replicate the file format, data type, and range of values

  • use the same compression scheme to simulate downloading the dataset

This is usually highly dependent on the dataset format and structure the new dataset comes in. You should always look for a similar dataset first and use that as a reference. However, below is an outline of the usual building blocks of a data.py script, for example an image segmentation dataset with 10 classes.

import os
import shutil
import tempfile
from collections.abc import Sequence

import numpy as np
from PIL import Image

# Fix the random seed so re-running the script produces the same dummy data
np.random.seed(0)

# Define the root directory and subdirectories
# Normally this would be the current directory (tests/data/my_new_dataset)
root_dir = os.path.join(tempfile.gettempdir(), 'my_new_dataset')
sub_dirs = ['image', 'target']
splits = ['train', 'val', 'test']

image_file_names = ['sample_1.png', 'sample_2.png', 'sample_3.png']

IMG_SIZE = 32
NUM_CLASSES = 10


# Function to create dummy input images
def create_input_image(
    path: str, shape: tuple[int, ...], pixel_values: Sequence[int]
) -> None:
    # Sample random pixel values and store them as an 8-bit PNG
    data = np.random.choice(pixel_values, size=shape, replace=True).astype(np.uint8)
    img = Image.fromarray(data)
    img.save(path)


# Function to create dummy targets with class values 0..NUM_CLASSES-1
def create_target_images(split: str, filename: str) -> None:
    target_pixel_values = range(NUM_CLASSES)
    path = os.path.join(root_dir, 'target', split, filename)
    create_input_image(path, (IMG_SIZE, IMG_SIZE), target_pixel_values)


# Create a new clean version when re-running the script
if os.path.exists(root_dir):
    shutil.rmtree(root_dir)

# Create the directory structure
for sub_dir in sub_dirs:
    for split in splits:
        os.makedirs(os.path.join(root_dir, sub_dir, split), exist_ok=True)

# Create dummy data for all splits and filenames
for split in splits:
    for filename in image_file_names:
        # uint8 PNGs can only hold values 0-255, so sample from that range
        create_input_image(
            os.path.join(root_dir, 'image', split, filename),
            (IMG_SIZE, IMG_SIZE),
            range(256),
        )
        create_target_images(split, filename.replace('_', '_target_'))

# Zip the directory so that the archive entries start with my_new_dataset/
shutil.make_archive(
    root_dir, 'zip', os.path.dirname(root_dir), os.path.basename(root_dir)
)
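Many existing data.py scripts finish by printing the MD5 checksum of the generated archive, so it can be used wherever the dataset class or tests need it. It is also worth confirming that the archive contains the expected directory structure. The sketch below uses only the standard library. md5sum and archive_members are hypothetical helper names, and the archive path assumes the temporary-directory setup above. Zip archives store file modification times, so the checksum changes every time data.py is re-run.

```python
import hashlib
import os
import tempfile
import zipfile


def md5sum(path: str, chunk_size: int = 2**20) -> str:
    """MD5 hex digest of a file, read in chunks to keep memory use low."""
    digest = hashlib.md5()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            digest.update(chunk)
    return digest.hexdigest()


def archive_members(path: str) -> list[str]:
    """Sorted list of the entries stored in a zip archive."""
    with zipfile.ZipFile(path) as zf:
        return sorted(zf.namelist())


# Archive produced by the data.py script above (a temporary directory here)
archive = os.path.join(tempfile.gettempdir(), 'my_new_dataset.zip')
if os.path.exists(archive):
    print(f'{os.path.basename(archive)}: {md5sum(archive)}')
    # Every entry should start with 'my_new_dataset/' and mirror the real layout
    for member in archive_members(archive)[:5]:
        print(member)
```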

The test_dataset_name.py file

The test_dataset_name.py file is placed in the tests/datasets/ directory. This file implements the unit tests for the dataset, such that every line of code in dataset_name.py is tested. The individual test cases will likely be very similar to those in existing test files, so you can look at those to see how to test the individual parts of the dataset logic.

import shutil
from pathlib import Path

import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from matplotlib import pyplot as plt
from pytest import MonkeyPatch

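from torchgeo.datasets import DatasetNotFoundError

# In the real tests/datasets/test_my_new_dataset.py you would also import
# MyNewDataset from torchgeo.datasets; in this notebook it is defined above.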


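# Mock of the download function: copy the local dummy archive instead of
# fetching it from the web
def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
    shutil.copy(url, root)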


class TestMyNewDataset:
    # pytest fixtures can be used to define variables to test different argument
    # configurations to test, for example the different splits of the dataset
    # or subselection of modalities/bands
    @pytest.fixture(params=['train', 'val', 'test'])
    def dataset(
        self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
    ) -> MyNewDataset:
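        # monkeypatch can overwrite the class attributes defined above the __init__
        # method and use the specific unit test settings to mock behavior. Here the
        # URL points to the local dummy archive created by data.py, and the download
        # function is replaced by the local copy defined above. The module path
        # 'torchgeo.datasets.my_new_dataset' is hypothetical.
        url = str(Path('tests', 'data', 'my_new_dataset', 'my_new_dataset.zip'))
        monkeypatch.setattr(MyNewDataset, 'url', url)
        monkeypatch.setattr(
            'torchgeo.datasets.my_new_dataset.download_url', download_url
        )

        split: str = request.param
        transforms = nn.Identity()
        return MyNewDataset(tmp_path, split=split, transforms=transforms, download=True)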

    def test_getitem(self, dataset: MyNewDataset) -> None:
        # Retrieve a sample and check some of the desired properties
        x = dataset[0]
        assert isinstance(x, dict)
        assert isinstance(x['image'], torch.Tensor)
        assert isinstance(x['label'], torch.Tensor)

    # For all additional class arguments, check behavior for invalid parameters
    def test_invalid_split(self) -> None:
        # Assumes __init__ validates the split with an assert statement
        with pytest.raises(AssertionError):
            MyNewDataset(split='foo')

    # Test the length of the dataset, this should coincide with the dummy data
    # (the data.py script above creates three samples per split)
    def test_len(self, dataset: MyNewDataset) -> None:
        assert len(dataset) == 3

    # Test the logic when the dataset is already downloaded
    def test_already_downloaded(self, dataset: MyNewDataset, tmp_path: Path) -> None:
        MyNewDataset(root=tmp_path, download=True)

    # Test the logic when the dataset is already downloaded but not extracted
    def test_already_downloaded_not_extracted(
        self, dataset: MyNewDataset, tmp_path: Path
    ) -> None:
        # Remove everything, recreate an empty root, and place only the archive in it
        shutil.rmtree(dataset.root)
        tmp_path.mkdir()
        download_url(dataset.url, root=tmp_path)
        MyNewDataset(root=tmp_path, download=False)

    # Test the logic when the dataset is not downloaded
    def test_not_downloaded(self, tmp_path: Path) -> None:
        with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
            MyNewDataset(tmp_path)

    # Test the plotting method through something like the following
    def test_plot(self, dataset: MyNewDataset) -> None:
        x = dataset[0].copy()
        x['prediction'] = x['label'].clone()
        dataset.plot(x, suptitle='Test')
        plt.close()

Documentation Entries

The entry point for new and experienced users of domain libraries is often the documentation page that accompanies a GitHub repository. TorchGeo uses the popular Sphinx framework to build its documentation. To display the docstrings you have written in dataset_name.py on the documentation page, create an entry in docs/api/datasets.rst, in alphabetical order:

Dataset Name
^^^^^^^^^^^^

.. autoclass:: MyNewDataset

Additionally, add a row in the non_geo_datasets.csv file under docs/api/datasets to include the dataset in the overview table.
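Editing the table by hand is perfectly fine. If you would rather script it, the sketch below inserts a row at its alphabetical position by the first column. It assumes the table follows the same alphabetical convention as datasets.rst. insert_row_sorted is a hypothetical helper. It reads the header from the existing text rather than assuming any column names. However, csv.writer rewrites every row with minimal quoting, so review the resulting diff.

```python
import bisect
import csv
import io


def insert_row_sorted(csv_text: str, row: list[str]) -> str:
    """Insert one row into CSV text at its alphabetical position (first column).

    The header is read from the existing text, so no column names are assumed.
    Existing rows keep their order; only the position of the new row is computed.
    """
    reader = csv.reader(io.StringIO(csv_text))
    header = next(reader)
    if len(row) != len(header):
        raise ValueError(f'expected {len(header)} fields, got {len(row)}')
    rows = [r for r in reader if r]  # skip blank lines
    keys = [r[0].casefold() for r in rows]
    rows.insert(bisect.bisect(keys, row[0].casefold()), row)

    out = io.StringIO()
    writer = csv.writer(out, lineterminator='\n')
    writer.writerow(header)
    writer.writerows(rows)
    return out.getvalue()


# Usage with a made-up two-column table (the real file has its own columns)
table = 'Dataset,Task\nAlpha,C\nGamma,S\n'
print(insert_row_sorted(table, ['Beta', 'S']))
```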

Linters

See the linter docs for an overview of linters that TorchGeo employs and how to apply them during commits.

Test Coverage

TorchGeo maintains a test coverage of 100%. This means that every line of code within the torchgeo directory is run by at least one unit test. The testing docs explain how you can check the coverage locally for the dataset_name.py file that you are adding.

Final Checklist

This final checklist gives an overview of the individual parts discussed in this tutorial. You do not need to tick every box before submitting a PR. If you have any questions, feel free to ask on Slack. You can also open a PR early so that maintainers and other community members can answer specific questions or give pointers. If you want the CI tests to run against your code while you keep working on the remaining boxes, you can convert the PR to a draft via the menu on the right side.

  • Dataset implementation in dataset_name.py

    • Class docstring containing:

      • Summary intro

      • Dataset features

      • Dataset format

      • Link to publication

      • versionadded tag

      • if applicable a note on additional dependencies

    • all class methods have docstrings

    • all class methods have argument and return type hints; mypy (the tool that checks type hints) can be confusing at first, so don’t hesitate to ask for help

    • if the dataset is hosted on GitHub or Hugging Face, the URL should contain the commit hash

    • checksum added

    • plot method that can display a single sample from the dataset (you can add the resulting figure in your PR description)

    • add the dataset to torchgeo/datasets/__init__.py

    • add the copyright header at the top of the file

  • Dummy data script data.py

    • replicate directory structure

    • replicate naming of directory and files

    • for image based datasets, use a small size, like 32x32

  • Unit tests test_dataset_name.py

    • 100% test coverage

  • Documentation with non_geo_datasets.csv and datasets.rst

    • entry in datasets.rst

    • entry in non_geo_datasets.csv

    • documentation displays properly, this can be checked locally or via the GitHub CI tests under docs/readthedocs.org:torchgeo
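Pinning a download URL to a commit hash, as the checklist asks for GitHub and Hugging Face, keeps the file behind it from changing. That way the checksum stored in the dataset class stays valid. The helpers below are a hypothetical sketch of building such URLs and rejecting branch names. The repository names and commit are placeholders.

```python
import re

# A full 40-character hexadecimal Git commit SHA
_COMMIT_SHA = re.compile(r'[0-9a-f]{40}')


def _require_commit(commit: str) -> None:
    # Branch names such as 'main' can move; a commit hash cannot
    if not _COMMIT_SHA.fullmatch(commit):
        raise ValueError(f'expected a full commit hash, got {commit!r}')


def github_raw_url(owner: str, repo: str, commit: str, path: str) -> str:
    """URL of a raw file in a GitHub repository, pinned to one commit."""
    _require_commit(commit)
    return f'https://raw.githubusercontent.com/{owner}/{repo}/{commit}/{path}'


def hf_dataset_url(repo_id: str, commit: str, filename: str) -> str:
    """URL of a file in a Hugging Face dataset repository, pinned to one commit."""
    _require_commit(commit)
    return f'https://huggingface.co/datasets/{repo_id}/resolve/{commit}/{filename}'


# Placeholder repository and commit, for illustration only
print(hf_dataset_url('example-user/my-new-dataset', '0' * 40, 'my_new_dataset.zip'))
```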
