[ ]:
# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.

Introduction to TorchGeo

Written by: Adam J. Stewart

Now that we’ve seen the basics of PyTorch and the challenges of working with geospatial data, let’s see how TorchGeo addresses these challenges.

Setup

First, we install TorchGeo and all of its dependencies.

[ ]:
%pip install torchgeo

Imports

Next, we import TorchGeo and any other libraries we need.

[ ]:
import os
import tempfile

from matplotlib import pyplot as plt
from torch.utils.data import DataLoader

from torchgeo.datasets import CDL, Landsat7, Landsat8, stack_samples
from torchgeo.datasets.utils import download_and_extract_archive
from torchgeo.samplers import GridGeoSampler, RandomGeoSampler

Motivation

Let’s start with a common task in geospatial machine learning to motivate us: land cover mapping. Imagine you have a collection of imagery and a land cover layer or mask you would like to learn to predict. In machine learning, this pixelwise classification process is referred to as semantic segmentation.

More concretely, imagine you would like to combine a set of Landsat 7 and 8 scenes with the Cropland Data Layer (CDL). This presents a number of challenges for a typical machine learning pipeline:

  • We may have hundreds of partially overlapping Landsat images that need to be mosaicked together

  • We have a single CDL mask covering the entire continental US

  • The Landsat input and the CDL output will not have the same geospatial bounds

  • Landsat is multispectral, and may have a different resolution for each spectral band

  • Landsat 7 and 8 have a different number of spectral bands

  • Landsat and CDL may have a different CRS

  • Every single Landsat file may be in a different CRS (e.g., multiple UTM zones)

  • We may have multiple years of input and output data, and need to ensure matching time spans

We can’t simply treat the entire collection as a dataset of length 1, and it isn’t obvious what to do when the number, bounds, and size of the input images differ from those of the output masks. Furthermore, each image is far too large to pass to a neural network in one piece.
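The sketch below makes the band mismatch from the list above concrete. It pairs Landsat 7 ETM+ and Landsat 8 OLI surface reflectance bands by spectral region. The dictionaries and the `shared_bands` helper are hypothetical illustrations, not part of TorchGeo. The resulting lists match the band selections used with the Landsat datasets below.

```python
# Hypothetical lookup tables (not part of TorchGeo): surface reflectance band
# names for each sensor, keyed by the spectral region the band measures.
LANDSAT7_ETM = {
    'blue': 'SR_B1',
    'green': 'SR_B2',
    'red': 'SR_B3',
    'nir': 'SR_B4',
    'swir1': 'SR_B5',
    'swir2': 'SR_B7',
}
LANDSAT8_OLI = {
    'coastal': 'SR_B1',  # coastal/aerosol band with no Landsat 7 equivalent
    'blue': 'SR_B2',
    'green': 'SR_B3',
    'red': 'SR_B4',
    'nir': 'SR_B5',
    'swir1': 'SR_B6',
    'swir2': 'SR_B7',
}


def shared_bands(a, b):
    """Return (band_a, band_b) pairs for spectral regions both sensors measure."""
    return [(a[region], b[region]) for region in a if region in b]


pairs = shared_bands(LANDSAT7_ETM, LANDSAT8_OLI)
landsat7_bands = [l7 for l7, _ in pairs]
landsat8_bands = [l8 for _, l8 in pairs]
print(landsat7_bands)  # ['SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B7']
print(landsat8_bands)  # ['SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7']
```

The same band name can mean different things on each sensor. For example, SR_B2 is green on Landsat 7 but blue on Landsat 8, which is why each dataset needs its own band list.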

Traditionally, people either performed classification on a single pixel at a time or curated their own benchmark dataset. This works fine for training, but isn’t really useful for inference. What we would really like to be able to do is sample small pixel-aligned pairs of input images and output masks from the region of overlap between both datasets. This exact situation is illustrated in the following figure:

(Figure: the region where Landsat and CDL intersect)

Now, let’s see what features TorchGeo has to support this kind of use case.

Datasets

Geospatial data comes in a wide variety of formats. TorchGeo has two separate classes of datasets to handle this diversity:

  • NonGeoDataset: for curated benchmark datasets, where geospatial metadata is either missing or unnecessary

  • GeoDataset: for uncurated raster and vector data layers, where geospatial metadata is critical for merging datasets

We have already seen the former in the Introduction to PyTorch tutorial, as EuroSAT100 is a subclass of NonGeoDataset. In this tutorial, we will focus on the latter and its advantages for working with uncurated data.

Landsat

Let’s start with our Landsat imagery. We will download one Landsat 7 scene and one Landsat 8 scene, then pass them to the built-in TorchGeo dataset for each sensor.

[ ]:
landsat_root = os.path.join(tempfile.gettempdir(), 'landsat')

url = 'https://hf.co/datasets/torchgeo/tutorials/resolve/ff30b729e3cbf906148d69a4441cc68023898924/'
landsat7_url = url + 'LE07_L2SP_022032_20230725_20230820_02_T1.tar.gz'
landsat8_url = url + 'LC08_L2SP_023032_20230831_20230911_02_T1.tar.gz'

download_and_extract_archive(landsat7_url, landsat_root)
download_and_extract_archive(landsat8_url, landsat_root)

landsat7_bands = ['SR_B1', 'SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B7']
landsat8_bands = ['SR_B2', 'SR_B3', 'SR_B4', 'SR_B5', 'SR_B6', 'SR_B7']

landsat7 = Landsat7(paths=landsat_root, bands=landsat7_bands)
landsat8 = Landsat8(paths=landsat_root, bands=landsat8_bands)

print(landsat7)
print(landsat8)

print(landsat7.crs)
print(landsat8.crs)

The following details are worth noting:

  • We ignore the “coastal blue” band of Landsat 8 because it does not exist in Landsat 7

  • Even though all files are stored in the same directory, the datasets know which files to include

  • paths can be a directory to recursively search, a list of local files, or even a list of remote cloud assets
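TorchGeo's raster datasets decide which files belong to them by matching file names against a pattern. `RasterDataset` exposes this through attributes such as `filename_glob` and `filename_regex`. The sketch below shows the idea with a simplified, hypothetical regex for Landsat Collection 2 band files. It is not the pattern TorchGeo actually ships. It also extracts the acquisition date from the name, which is the kind of timestamp a dataset can use for temporal indexing.

```python
import re
from datetime import datetime

# Simplified, hypothetical pattern for Landsat Collection 2 surface reflectance
# band files (not TorchGeo's actual filename_regex), e.g.
# LE07_L2SP_022032_20230725_20230820_02_T1_SR_B1.TIF
LANDSAT_RE = re.compile(
    r'^(?P<sensor>L[CE]0[78])'          # LE07 = Landsat 7, LC08 = Landsat 8
    r'_(?P<level>L2SP)'                 # processing level
    r'_(?P<path>\d{3})(?P<row>\d{3})'   # WRS-2 path and row
    r'_(?P<date>\d{8})'                 # acquisition date
    r'_\d{8}_\d{2}_T\d'                 # processing date, collection, tier
    r'_(?P<band>SR_B\d)\.TIF$'          # surface reflectance band
)


def parse(filename):
    """Return metadata for a matching file, or None if it should be skipped."""
    match = LANDSAT_RE.match(filename)
    if match is None:
        return None
    return {
        'sensor': match['sensor'],
        'path_row': (int(match['path']), int(match['row'])),
        'date': datetime.strptime(match['date'], '%Y%m%d').date(),
        'band': match['band'],
    }


files = [
    'LE07_L2SP_022032_20230725_20230820_02_T1_SR_B1.TIF',
    'LC08_L2SP_023032_20230831_20230911_02_T1_SR_B2.TIF',
    'LC08_L2SP_023032_20230831_20230911_02_T1_MTL.txt',
]
# Keep only Landsat 7 band files; non-matching files parse to None.
landsat7_files = [f for f in files if (parse(f) or {}).get('sensor') == 'LE07']
print(landsat7_files)
print(parse(files[0])['date'])  # 2023-07-25
```

Files that don't match the pattern, such as the metadata text file, are simply skipped. This is how a single directory can hold data for several datasets.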

CDL

Next, let’s do the same for the CDL dataset. We are using a smaller cropped version of this dataset to make the download faster.

[ ]:
cdl_root = os.path.join(tempfile.gettempdir(), 'cdl')

cdl_url = url + '2023_30m_cdls.zip'

download_and_extract_archive(cdl_url, cdl_root)

cdl = CDL(paths=cdl_root)

print(cdl)
print(cdl.crs)

Again, the following details are worth noting:

  • We could actually ask the CDL dataset to download the data for us by adding download=True, although this would fetch the full-size CDL rather than the cropped version used here

  • All datasets have different spatial extents

  • All datasets have different CRSs

Composing datasets

We would like to intelligently combine all three datasets in order to train a land cover mapping model. First, we need a virtual mosaic of all Landsat scenes, regardless of overlap. We can create one by taking the union of the Landsat 7 and Landsat 8 datasets.

[ ]:
landsat = landsat7 | landsat8
print(landsat)
print(landsat.crs)
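Conceptually, the union covers every location that either Landsat dataset covers. The sketch below computes the extent of such a virtual mosaic and finds which scenes can answer a query box. It uses made-up footprints and hypothetical helper names, and it illustrates the idea rather than TorchGeo's implementation.

```python
from typing import NamedTuple


class Box(NamedTuple):
    """Hypothetical axis-aligned bounding box in a shared CRS."""
    minx: float
    maxx: float
    miny: float
    maxy: float


def envelope(boxes):
    """Smallest box enclosing every input box: the extent of a virtual mosaic."""
    return Box(
        min(b.minx for b in boxes),
        max(b.maxx for b in boxes),
        min(b.miny for b in boxes),
        max(b.maxy for b in boxes),
    )


def overlaps(a, b):
    """True if two boxes share some area (touching edges do not count)."""
    return a.minx < b.maxx and b.minx < a.maxx and a.miny < b.maxy and b.miny < a.maxy


# Made-up footprints (arbitrary units) for two partially overlapping scenes
scenes = {'landsat7': Box(0, 100, 0, 100), 'landsat8': Box(60, 160, 20, 120)}
mosaic = envelope(scenes.values())
query = Box(70, 90, 30, 50)
sources = [name for name, footprint in scenes.items() if overlaps(footprint, query)]
print(mosaic)   # Box(minx=0, maxx=160, miny=0, maxy=120)
print(sources)  # ['landsat7', 'landsat8']
```

The envelope can include gaps that no scene covers. A union therefore still has to check which scenes actually overlap each query before reading any pixels.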

Similarly, we only want to sample from locations with both input imagery and output masks, not locations with only one or the other. We can achieve this by taking the intersection of the Landsat and CDL datasets.

[ ]:
dataset = landsat & cdl
print(dataset)
print(dataset.crs)

Note that all datasets now have the same CRS. When you run this code, you should notice that it finishes very quickly. That is because TorchGeo hasn’t actually created a mosaic or reprojected anything yet; it will do this on the fly when we request samples.
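An intersection uses the same kind of bounding boxes, but it also has to respect time: a mask is only useful for imagery acquired within the period the mask describes. The sketch below intersects two spatiotemporal boxes with made-up extents. One represents a Landsat scene, and the other a CDL layer assumed valid for all of 2023. The `STBox` type and `intersect` helper are hypothetical.

```python
from datetime import date
from typing import NamedTuple, Optional


class STBox(NamedTuple):
    """Hypothetical spatiotemporal box: x/y bounds in a shared CRS plus a date range."""
    minx: float
    maxx: float
    miny: float
    maxy: float
    start: date
    end: date


def intersect(a: STBox, b: STBox) -> Optional[STBox]:
    """Return the region covered by both boxes, or None if they don't overlap."""
    out = STBox(
        max(a.minx, b.minx), min(a.maxx, b.maxx),
        max(a.miny, b.miny), min(a.maxy, b.maxy),
        max(a.start, b.start), min(a.end, b.end),
    )
    # Empty in space (no area) or in time (no shared day) means no intersection.
    if out.minx >= out.maxx or out.miny >= out.maxy or out.start > out.end:
        return None
    return out


# Made-up extents: one Landsat scene and a CDL layer assumed valid for all of 2023
scene = STBox(0, 100, 0, 100, date(2023, 7, 25), date(2023, 7, 25))
cdl_2023 = STBox(50, 500, -50, 80, date(2023, 1, 1), date(2023, 12, 31))
cdl_2022 = cdl_2023._replace(start=date(2022, 1, 1), end=date(2022, 12, 31))

overlap = intersect(scene, cdl_2023)
print(overlap)
print(intersect(scene, cdl_2022))  # None
```

The 2022 layer covers the same place as the scene but shares no dates with it. This is exactly the "matching time spans" requirement from the motivation.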

Spatiotemporal indexing

How does this work? TorchGeo uses a GeoDataFrame to store the spatiotemporal bounding box of every file in the dataset. It extracts the spatial bounding box from each file’s metadata and the timestamp from its filename. This spatial and temporal metadata allows us to efficiently compute the intersection or union of two datasets. It also lets us quickly retrieve an image and its corresponding mask for a particular location in space and time.
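A minimal way to picture this index is a table with one row per file, queried by checking each row for overlap in space and time. The sketch below is a hypothetical linear-scan version with made-up file names and bounds. A real implementation would use a spatial index, such as the one GeoPandas provides through `GeoDataFrame.sindex`, instead of scanning every row.

```python
from datetime import date
from typing import List, NamedTuple


class Entry(NamedTuple):
    """One row of a hypothetical index: a file's footprint and time span."""
    path: str
    minx: float
    maxx: float
    miny: float
    maxy: float
    start: date
    end: date


def query(index: List[Entry], minx, maxx, miny, maxy, start, end) -> List[str]:
    """Return paths of files whose spatiotemporal bounds intersect the query."""
    return [
        e.path
        for e in index
        if e.minx < maxx and minx < e.maxx  # overlap in x
        and e.miny < maxy and miny < e.maxy  # overlap in y
        and e.start <= end and start <= e.end  # overlap in time
    ]


index = [
    Entry('scene_a.tif', 0, 100, 0, 100, date(2023, 7, 25), date(2023, 7, 25)),
    Entry('scene_b.tif', 80, 180, 0, 100, date(2023, 8, 31), date(2023, 8, 31)),
    Entry('scene_c.tif', 500, 600, 0, 100, date(2023, 8, 31), date(2023, 8, 31)),
]
print(query(index, 90, 95, 10, 20, date(2023, 1, 1), date(2023, 12, 31)))
# ['scene_a.tif', 'scene_b.tif']
print(query(index, 90, 95, 10, 20, date(2023, 8, 1), date(2023, 8, 31)))
# ['scene_b.tif']
```

Narrowing the time range drops scene_a, even though it covers the query location.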

[ ]:
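size = 256  # patch size in pixels

# Query bounds in the dataset's CRS units. At 30 m resolution,
# a 256-pixel patch spans 256 * 30 = 7680 units on each side.
xmin = 925000
xmax = xmin + size * 30
ymin = 4470000
ymax = ymin + size * 30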

sample = dataset[xmin:xmax, ymin:ymax]

landsat8.plot(sample)
cdl.plot(sample)
plt.show()

TorchGeo uses windowed reading to read only the blocks of a file needed to load a small patch from a large raster tile. It also automatically reprojects all data to the same CRS and resolution (taken from the first dataset). You can override this by explicitly passing crs or res to the dataset.
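Windowed reading starts by converting the query bounds into a pixel window. The sketch below does this for a north-up raster with square pixels and no rotation. The raster origin is made up and the helper is hypothetical. Libraries such as rasterio provide equivalent functionality, for example `rasterio.windows.from_bounds`.

```python
import math
from typing import NamedTuple


class Window(NamedTuple):
    col_off: int
    row_off: int
    width: int
    height: int


def bounds_to_window(minx, maxx, miny, maxy, left, top, res):
    """Map CRS bounds to a pixel window of a north-up raster.

    Assumes the raster's upper-left corner is (left, top) and pixels are
    square with side length `res` in CRS units (no rotation).
    """
    col_off = math.floor((minx - left) / res)
    row_off = math.floor((top - maxy) / res)  # rows increase downward
    col_end = math.ceil((maxx - left) / res)
    row_end = math.ceil((top - miny) / res)
    return Window(col_off, row_off, col_end - col_off, row_end - row_off)


size, res = 256, 30
xmin, ymin = 925000, 4470000
# Hypothetical raster whose upper-left corner is at (901000, 4500000)
window = bounds_to_window(xmin, xmin + size * res, ymin, ymin + size * res,
                          left=901000, top=4500000, res=res)
print(window)  # Window(col_off=800, row_off=744, width=256, height=256)
```

Only this 256 × 256 window needs to be read from disk, not the whole scene.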

Samplers

The above slice makes it easy to index into complex datasets consisting of hundreds of files. However, it is a bit cumbersome to manually construct these queries every time, especially if we want thousands or even millions of bounding boxes. Luckily, TorchGeo provides a GeoSampler class to construct these for us.

Random sampling

Usually, at training time, we want the largest possible dataset we can muster. For curated benchmark datasets like EuroSAT100, we achieved this by applying data augmentation to artificially inflate the size and diversity of our dataset. For GeoDataset objects, we can achieve this using random sampling. It doesn’t matter if two or more of our patches partially overlap, as long as they bring unique pixels that help our model learn.

TorchGeo provides a RandomGeoSampler to achieve this. We just tell the sampler how large we want each image patch to be (in pixel coordinates or CRS units) and, optionally, the number of image patches per epoch.

[ ]:
train_sampler = RandomGeoSampler(dataset, size=size, length=1000)
next(iter(train_sampler))
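Under the hood, a random geospatial sampler just draws bounding boxes. The sketch below yields square boxes placed uniformly at random inside a region of interest. The `Box` type, the `random_boxes` helper, and the region bounds are all hypothetical. TorchGeo's `RandomGeoSampler` differs in the details, for example in how it decides which part of the dataset to sample from.

```python
import random
from typing import Iterator, NamedTuple


class Box(NamedTuple):
    minx: float
    maxx: float
    miny: float
    maxy: float


def random_boxes(roi: Box, size: float, length: int, seed: int = 0) -> Iterator[Box]:
    """Yield `length` square boxes of side `size`, placed uniformly inside `roi`."""
    rng = random.Random(seed)  # seeded for reproducibility
    for _ in range(length):
        # Choose the lower-left corner so the whole box stays inside the region.
        minx = rng.uniform(roi.minx, roi.maxx - size)
        miny = rng.uniform(roi.miny, roi.maxy - size)
        yield Box(minx, minx + size, miny, miny + size)


# Made-up region of interest in CRS units; 256 pixels at 30 m per pixel
roi = Box(925000, 1000000, 4470000, 4550000)
boxes = list(random_boxes(roi, size=256 * 30, length=1000))
print(len(boxes))  # 1000
print(boxes[0])
```

Boxes drawn this way can overlap each other freely, which is acceptable for training as described above.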

Gridded sampling

At evaluation time, however, overlapping patches become a problem: we want to make sure we aren’t making multiple predictions for the same location, and we also want to make sure we don’t miss any locations. To achieve this, TorchGeo also provides a GridGeoSampler. We tell the sampler the size of each image patch and the stride of our sliding window.

[ ]:
test_sampler = GridGeoSampler(dataset, size=size, stride=size)
next(iter(test_sampler))
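A gridded sampler can be sketched the same way: step a window across the region with a fixed stride, row by row. In this hypothetical version, the region may not be a whole multiple of the stride. In that case, the last window along each axis is shifted back to sit flush with the far edge. No pixels are missed, at the cost of a few edge pixels being covered twice. TorchGeo's `GridGeoSampler` may handle edges differently.

```python
from typing import Iterator, List, NamedTuple


class Box(NamedTuple):
    minx: float
    maxx: float
    miny: float
    maxy: float


def window_starts(lo: float, hi: float, size: float, stride: float) -> List[float]:
    """Start coordinates of sliding windows along one axis, covering [lo, hi]."""
    if hi - lo <= size:
        return [lo]  # region smaller than one window
    starts = []
    x = lo
    while x + size < hi:
        starts.append(x)
        x += stride
    # Add a final window flush with the far edge so no pixels are missed.
    starts.append(hi - size)
    return starts


def grid_boxes(roi: Box, size: float, stride: float) -> Iterator[Box]:
    """Yield a sliding-window grid of square boxes covering `roi`, row by row."""
    for miny in window_starts(roi.miny, roi.maxy, size, stride):
        for minx in window_starts(roi.minx, roi.maxx, size, stride):
            yield Box(minx, minx + size, miny, miny + size)


boxes = list(grid_boxes(Box(0, 10, 0, 8), size=4, stride=4))
print(len(boxes))  # 6
```

Along x, the windows start at 0, 4, and 6. Along y, they start at 0 and 4, because 8 divides evenly into two windows of size 4.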

Data Loaders

All of these abstractions (GeoDataset and GeoSampler) are fully compatible with the rest of PyTorch. We can simply pass them to a data loader as shown below. Note that we also need the stack_samples collation function to convert a list of samples into a mini-batch.

[ ]:
train_dataloader = DataLoader(
    dataset, batch_size=128, sampler=train_sampler, collate_fn=stack_samples
)
test_dataloader = DataLoader(
    dataset, batch_size=128, sampler=test_sampler, collate_fn=stack_samples
)
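A collation function turns the list of per-sample dictionaries returned by the dataset into a single batch. The sketch below is a simplified, framework-free stand-in that only regroups values by key. Our understanding is that `stack_samples` additionally stacks tensor values into a batch tensor. The `collate_dicts` name and the sample contents are hypothetical.

```python
from typing import Any, Dict, List


def collate_dicts(samples: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Regroup a list of per-sample dicts into one dict of per-key lists."""
    return {key: [sample[key] for sample in samples] for key in samples[0]}


# Two made-up samples: tiny 2x2 "images" and "masks" plus their bounds
samples = [
    {'image': [[1, 2], [3, 4]], 'mask': [[0, 1], [1, 0]], 'bounds': (0, 2, 0, 2)},
    {'image': [[5, 6], [7, 8]], 'mask': [[1, 1], [0, 0]], 'bounds': (2, 4, 0, 2)},
]
batch = collate_dicts(samples)
print(batch['image'])   # [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
print(batch['bounds'])  # [(0, 2, 0, 2), (2, 4, 0, 2)]
```

With real tensors, the image and mask lists would instead become single tensors whose leading dimension is the batch size (128 in the data loaders above).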

Now that we have working data loaders, we can copy and paste our training code from the Introduction to PyTorch tutorial. We only need to change our model to one designed for semantic segmentation, such as a U-Net. Every other line of code would be identical to your normal PyTorch workflow.

Additional Reading

TorchGeo has plenty of other tutorials and documentation. If you would like to get more insight into the design of TorchGeo, the following external resources are also helpful:
