[ ]:
# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.
Contribute a New Non-Geospatial Dataset¶
Written by: Nils Lehmann
Open-source datasets have significantly accelerated machine learning research. Geospatial machine learning datasets can be particularly complex to work with compared to standard RGB-based vision datasets. To spare the community from implementing the same data loading logic over and over, TorchGeo provides dozens of built-in datasets that can be downloaded and made ready for use in a PyTorch framework with a single line of code. This tutorial shows how you can add a new non-geospatial dataset to this growing collection.
As a reminder, TorchGeo differentiates between two types of datasets: geospatial and non-geospatial datasets. Non-geospatial datasets are integer indexed, like the datasets one might be familiar with from torchvision, while geospatial datasets are indexed via spatiotemporal bounding boxes. Non-geospatial datasets can still return geospatial and other metadata and should be specific to the remote sensing domain.
Setup¶
First, we install TorchGeo and its dependencies.
[ ]:
%pip install torchgeo
Where to start¶
There are many types of remote sensing datasets. Satellite-Image-Deep-Learning maintains a list of many of these datasets, as well as links to other similar curated lists.
Two aspects that will make it a lot easier to add a dataset are whether or not the dataset can be easily downloaded and whether or not it comes with a GitHub repository and publication that outlines how the authors intend the dataset to be used. These are not necessary criteria, and sometimes it might be even more worthwhile to add a dataset without an existing code base, precisely because the marginal contribution to the community might be greater: once the dataset is in TorchGeo, using it no longer requires writing the loading implementation from scratch.
Adding the dataset¶
Once you have identified a dataset that you would like to add to TorchGeo, determine what application category it roughly falls into: for example, a segmentation dataset based on a collection of .png files, versus a classification dataset based on pre-defined image chips in .tif files. In the latter case, if the dataset contains .tif files with very large pixel dimensions, such that loading a single file might be costly, consider adding the dataset as a geospatial dataset for easier indexing. Once you have identified the task (e.g., segmentation vs. classification) and the dataset format, check whether a dataset of the same or a similar category already exists in TorchGeo. All datasets inherit from either the NonGeoDataset or the GeoDataset base class, which provides an outline for the implementation logic as well as additional utility functions that should be reused. This reduces code duplication and makes it easier to unit test datasets.
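Before writing any code, it can help to see the interface that a finished NonGeoDataset exposes to users, since your new dataset should behave the same way. The short sketch below uses EuroSAT purely as an example of an existing classification dataset; downloading it requires an internet connection, and the exact tensor shapes depend on the dataset.
[ ]:
from torchgeo.datasets import EuroSAT

# EuroSAT is only used to illustrate the NonGeoDataset interface that a new
# dataset should follow: integer indexing returns a dictionary of tensors.
dataset = EuroSAT(root='data', split='train', download=True)

print(len(dataset))  # number of samples in the split
sample = dataset[0]  # integer indexing, just like torchvision datasets
print(sample['image'].shape, sample['label'])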
Adding a dataset to TorchGeo consists of roughly four steps:
1. a dataset_name.py file that implements the logic of the dataset
2. a data.py file that creates dummy data in the same structure and format as the original dataset for unit tests
3. a test_dataset_name.py file that implements unit tests for the dataset
4. an entry in the documentation page files non_geo_datasets.csv and datasets.rst
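For a dataset hypothetically called my_new_dataset, these pieces end up in the following locations in the repository (the paths mirror the sections below):
torchgeo/datasets/my_new_dataset.py        (dataset implementation)
tests/data/my_new_dataset/data.py          (dummy data script)
tests/datasets/test_my_new_dataset.py      (unit tests)
docs/api/datasets.rst                      (autoclass entry)
docs/api/datasets/non_geo_datasets.csv     (row in the overview table)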
The dataset_name.py file¶
This file implements the logic to load a sample from the dataset, as well as to download the dataset automatically if possible. The new dataset inherits from a base class, and the documentation string (docstring) of the class should contain:
a short summary of the dataset
an outline of its features, such as the task the dataset is designed to solve
an outline of the format the dataset comes in, e.g., file types, pixel dimensions, etc.
a proper reference to the dataset, such as a link to the paper, so users can adequately cite the dataset when using it
if required, a note about additional dependencies that are not part of TorchGeo's required dependencies
The dataset implementation itself should contain:
a method to create an index structure the dataset can iterate over to load samples; this index structure also defines the length (__len__) of the dataset, i.e., how many individual samples can be loaded from the dataset
a __getitem__ method that takes an integer index argument, loads a sample of the dataset, and returns its components in a dictionary
a _verify method that checks whether the dataset can be found on the filesystem, has already been downloaded and only needs to be extracted, or downloads and extracts the dataset from the web
a plot method that can visually display a single sample of the dataset
The code below attempts to roughly outline the parts required for a new NonGeoDataset. Specifics are of course very dependent on the type of dataset you want to add, but this template and other existing datasets should give you a decent starting point.
[ ]:
from collections.abc import Callable

from matplotlib.pyplot import Figure
from torch import Tensor

from torchgeo.datasets import NonGeoDataset
from torchgeo.datasets.utils import Path


class MyNewDataset(NonGeoDataset):
    """MyNewDataset.

    Short summary of the dataset and link to its homepage.

    Dataset features:

    * number of classes
    * sensors
    * area covered
    * etc.

    Dataset format:

    * what file format and shape the input data comes in
    * what file format and shape the target data comes in
    * possible metadata files

    If you use this dataset in your research, please cite the following paper:

    * URL of publication or citation information

    .. versionadded:: next TorchGeo minor release version, e.g., 1.0
    """

    # In this part of the code you can define class attributes such as a list of
    # class names, color maps, url and checksums for data download, and other
    # attributes that one might require repeatedly in the subsequent class methods.

    def __init__(
        self,
        root: Path = 'data',
        split: str = 'train',
        transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
        download: bool = False,
    ) -> None:
        """Initialize the dataset.

        The init parameters can include additional arguments, such as an option to
        select specific image bands, data modalities, or other arguments that give
        greater control over data loading. They should all have reasonable defaults.

        Args:
            root: root directory where dataset can be found
            split: one of "train", "val", or "test"
            transforms: a function/transform that takes input sample and its target as
                entry and returns a transformed version
            download: if True, download dataset and store it in the root directory
        """

    def __len__(self) -> int:
        """The length of the dataset.

        This is the total number of samples per epoch, and is used to define the
        maximum allowed index that can be passed to `__getitem__`.
        """

    def __getitem__(self, index: int) -> dict[str, Tensor]:
        """A single sample from the dataset.

        Load a single input image and target label or mask, and return it in a
        dictionary.
        """

    def plot(
        self, sample: dict[str, Tensor], suptitle: str | None = None
    ) -> Figure:
        """Plot a sample of the dataset for visualization purposes.

        This might involve selecting the RGB bands, using a colormap to display a mask,
        adding a legend with class labels, etc.
        """
The data.py file¶
The data.py file is placed under the tests/data/dataset_name/ directory and creates a smaller dummy dataset that replicates the features and format of the actual full dataset for unit tests. This is needed to keep the tests fast (we don't have time or storage space to download the real dataset) and to comply with the dataset license.
The script should:
replicate the directory structure and file names
replicate the file format, data type, and range of values
use the same compression scheme to simulate downloading the dataset
This is usually highly dependent on the format and structure of the new dataset, so you should always look for a similar dataset first and use it as a reference. Nevertheless, below is an outline of the usual building blocks of a data.py script, in this case for an image segmentation dataset with 10 classes.
[ ]:
import os
import shutil
import tempfile

import numpy as np
from PIL import Image

# Define the root directory and subdirectories
# Normally this would be the current directory (tests/data/my_new_dataset)
root_dir = os.path.join(tempfile.gettempdir(), 'my_new_dataset')
sub_dirs = ['image', 'target']
splits = ['train', 'val', 'test']

image_file_names = ['sample_1.png', 'sample_2.png', 'sample_3.png']

IMG_SIZE = 32


# Function to create dummy input images
def create_input_image(
    path: str, shape: tuple[int, int], pixel_values: range
) -> None:
    data = np.random.choice(pixel_values, size=shape, replace=True).astype(np.uint8)
    img = Image.fromarray(data)
    img.save(path)


# Function to create dummy targets
def create_target_images(split: str, filename: str) -> None:
    target_pixel_values = range(10)
    path = os.path.join(root_dir, 'target', split, filename)
    create_input_image(path, (IMG_SIZE, IMG_SIZE), target_pixel_values)


# Create a new clean version when re-running the script
if os.path.exists(root_dir):
    shutil.rmtree(root_dir)

# Create the directory structure
for sub_dir in sub_dirs:
    for split in splits:
        os.makedirs(os.path.join(root_dir, sub_dir, split), exist_ok=True)

# Create dummy data for all splits and filenames
for split in splits:
    for filename in image_file_names:
        create_input_image(
            os.path.join(root_dir, 'image', split, filename),
            (IMG_SIZE, IMG_SIZE),
            range(2**16),
        )
        create_target_images(split, filename.replace('_', '_target_'))

# Zip the directory so it mimics the downloaded archive; the base directory is given
# relative to the temporary directory so the archive contains 'my_new_dataset/...'
shutil.make_archive(root_dir, 'zip', tempfile.gettempdir(), 'my_new_dataset')
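Most dataset classes hard-code an md5 checksum of the archive so that _verify can validate downloads, and many of the existing data.py scripts print it for convenience. A small addition along these lines, appended to the script above, computes it:
[ ]:
import hashlib

# Compute the md5 checksum of the dummy archive so it can be hard-coded in the
# dataset class and in the unit tests
zip_path = root_dir + '.zip'
with open(zip_path, 'rb') as f:
    md5 = hashlib.md5(f.read()).hexdigest()
print(zip_path, md5)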
The test_dataset_name.py file¶
The test_dataset_name.py file is placed under the tests/datasets/ directory. This file implements the unit tests for the dataset, such that every line of code in dataset_name.py is tested. The logic of the individual test cases will likely be very similar to that of existing test files, so you can look at those to see how you can test the individual parts of the dataset logic.
[ ]:
import shutil
from pathlib import Path

import pytest
import torch
import torch.nn as nn
from _pytest.fixtures import SubRequest
from matplotlib import pyplot as plt
from pytest import MonkeyPatch

from torchgeo.datasets import DatasetNotFoundError


def download_url(url: str, root: str | Path, *args: str, **kwargs: str) -> None:
    shutil.copy(url, root)


class TestMyNewDataset:
    # pytest fixtures can be used to define variables to test different argument
    # configurations to test, for example the different splits of the dataset
    # or subselection of modalities/bands
    @pytest.fixture(params=['train', 'val', 'test'])
    def dataset(
        self, monkeypatch: MonkeyPatch, tmp_path: Path, request: SubRequest
    ) -> MyNewDataset:
        # monkeypatch can overwrite the class attributes defined above the __init__
        # method and use the specific unit tests settings to mock behavior
        split: str = request.param
        transforms = nn.Identity()
        return MyNewDataset(
            tmp_path, split=split, transforms=transforms, download=True
        )

    def test_getitem(self, dataset: MyNewDataset) -> None:
        # Retrieve a sample and check some of the desired properties
        x = dataset[0]
        assert isinstance(x, dict)
        assert isinstance(x['image'], torch.Tensor)
        assert isinstance(x['label'], torch.Tensor)

    # For all additional class arguments, check behavior for invalid parameters
    def test_invalid_split(self) -> None:
        with pytest.raises(AssertionError):
            MyNewDataset(split='foo')

    # Test the length of the dataset, this should coincide with the dummy data
    def test_len(self, dataset: MyNewDataset) -> None:
        assert len(dataset) == 3

    # Test the logic when the dataset is already downloaded
    def test_already_downloaded(self, dataset: MyNewDataset, tmp_path: Path) -> None:
        MyNewDataset(root=tmp_path, download=True)

    # Test the logic when the dataset is already downloaded but not extracted
    def test_already_downloaded_not_extracted(
        self, dataset: MyNewDataset, tmp_path: Path
    ) -> None:
        shutil.rmtree(dataset.root)
        download_url(dataset.url, root=tmp_path)
        MyNewDataset(root=tmp_path, download=False)

    # Test the logic when the dataset is not downloaded
    def test_not_downloaded(self, tmp_path: Path) -> None:
        with pytest.raises(DatasetNotFoundError, match='Dataset not found'):
            MyNewDataset(tmp_path)

    # Test the plotting method through something like the following
    def test_plot(self, dataset: MyNewDataset) -> None:
        x = dataset[0].copy()
        x['prediction'] = x['label'].clone()
        dataset.plot(x, suptitle='Test')
        plt.close()
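For reference, the monkeypatching mentioned in the fixture usually amounts to pointing the class attributes that control the download at the dummy archive created by data.py. A minimal sketch, assuming MyNewDataset defines url and md5 class attributes (adapt the names to your class), would go inside the fixture before the dataset is constructed:
# These lines belong inside the dataset fixture; 'url' and 'md5' are assumed
# class attributes of MyNewDataset.
dummy_zip = 'tests/data/my_new_dataset/my_new_dataset.zip'
monkeypatch.setattr(MyNewDataset, 'url', dummy_zip)
monkeypatch.setattr(MyNewDataset, 'md5', '<md5 of the dummy archive>')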
Documentation Entries¶
The entry point for new and experienced users of domain libraries is often the dedicated documentation page that accompanies a GitHub repository. TorchGeo uses the popular Sphinx framework to build its documentation. To display the documentation strings you have written in dataset_name.py
on the actual documentation page, you need to create an entry in docs/api/datasets.rst
in alphabetical order:
Dataset Name
^^^^^^^^^^^^
.. autoclass:: MyNewDataset
Additionally, add a row in the non_geo_datasets.csv
file under docs/api/datasets
to include the dataset in the overview table.
Linters¶
See the linter docs for an overview of linters that TorchGeo employs and how to apply them during commits.
Test Coverage¶
TorchGeo maintains a test coverage of 100%. This means that every line of code within the torchgeo directory is run by some unit test. The testing docs provide instructions on how you can check the coverage locally for the dataset_name.py file that you are adding.
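Assuming pytest and the pytest-cov plugin are installed, one way to check coverage locally for just the new dataset is a command along these lines (the testing docs describe the canonical workflow):
pytest --cov=torchgeo.datasets --cov-report=term-missing tests/datasets/test_dataset_name.py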
Final Checklist¶
This final checklist might provide a useful overview of the individual parts discussed in this tutorial. You definitely do not need to check all boxes before submitting a PR. If you have any questions, feel free to ask on Slack or open a PR early so that maintainers or other community members can answer specific questions or give pointers. If you want to submit your PR as a work in progress, so that the CI tests run against your code while you keep ticking more boxes, you can also convert the PR to a draft via the menu on the right-hand side.
Dataset implementation in dataset_name.py:
- class docstring containing:
  - summary intro
  - dataset features
  - dataset format
  - link to publication
  - versionadded tag
  - if applicable, a note on additional dependencies
- all class methods have docstrings
- all class methods have argument and return type hints; mypy (the tool that checks type hints) can be confusing at the beginning, so don't hesitate to ask for help
- if the dataset is hosted on GitHub or Hugging Face, the url link should contain the commit hash
- checksum added
- plot method that can display a single sample from the dataset (you can add the resulting figure in your PR description)
- add the dataset to torchgeo/datasets/__init__.py
- add the copyright notice at the top of the file

Dummy data script data.py:
- replicate directory structure
- replicate naming of directories and files
- for image-based datasets, use a small image size, such as 32x32

Unit tests in test_dataset_name.py:
- 100% test coverage

Documentation with non_geo_datasets.csv and datasets.rst:
- entry in datasets.rst
- entry in non_geo_datasets.csv
- documentation displays properly; this can be checked locally or via the GitHub CI tests under docs/readthedocs.org:torchgeo