[ ]:
# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.

Earth Surface Water

Written by: Mauricio Cordeiro

Introduction

The objective of this tutorial is to go through the Earth Surface Water dataset and cover the following topics:

  • Creating RasterDatasets, DataLoaders and Samplers for images and masks;

  • Intersection Dataset;

  • Normalizing the data;

  • Creating spectral indices;

  • Creating the segmentation model (DeepLabV3);

  • Loss function and metrics; and

  • Training loop.

Environment

For the environment, we will install the torchgeo and scikit-learn packages.

[ ]:
%pip install torchgeo scikit-learn

Imports

[ ]:
import tempfile
from collections.abc import Callable, Iterable
from pathlib import Path

import kornia.augmentation as K
import matplotlib.pyplot as plt
import numpy as np
import rasterio as rio
import torch
from sklearn.metrics import jaccard_score
from torch.utils.data import DataLoader

from torchgeo.datasets import RasterDataset, stack_samples, unbind_samples, utils
from torchgeo.samplers import RandomGeoSampler, Units
from torchgeo.transforms import indices

Dataset

The dataset we will use is the Earth Surface Water dataset [1] (licensed under the Creative Commons Attribution 4.0 International Public License), which contains patches from different parts of the world (figure below) and their corresponding water masks. The dataset uses optical imagery from the Sentinel-2 satellite with 10 m spatial resolution.

Figure: locations of the dataset patches around the world.

[1] Xin Luo. (2021). Earth Surface Water Dataset [Data set]. Zenodo. https://doi.org/10.5281/zenodo.5205674

[ ]:
# Download and extract dataset to a temp folder
tmp_path = Path(tempfile.gettempdir()) / 'surface_water/'
utils.download_and_extract_archive(
    'https://hf.co/datasets/cordmaur/earth_surface_water/resolve/main/earth_surface_water.zip',
    tmp_path,
)

# Set the root to the extracted folder
root = tmp_path / 'dset-s2'

Creating the Datasets

Now that the dataset has been downloaded and extracted to a temporary folder, we can prepare it to be loaded into a neural network. For that, we will create instances of the RasterDataset class, provided by TorchGeo, each pointing to a specific directory, using the following commands. The scale function applies the 1e-4 scale factor (division by 10,000) needed to convert the Sentinel-2 values to reflectance. Once the datasets are created, we can combine images with masks (labels) using the & operator, which creates an intersection dataset.
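As a quick check of the reflectance conversion, the snippet below applies the same scale function to a mock sample. The 6-band, 2 x 2 array and its value of 2500 are made-up numbers for illustration, not values from the dataset; the function is repeated so the snippet runs on its own.

```python
import torch


# Same conversion as in the notebook: Sentinel-2 digital numbers / 10,000 = reflectance
def scale(item: dict) -> dict:
    item['image'] = item['image'] / 10000
    return item


# Hypothetical mock sample: 6 bands, 2 x 2 pixels, every digital number set to 2500
mock = {'image': torch.full((6, 2, 2), 2500, dtype=torch.int16)}
scaled = scale(mock)
print(scaled['image'][0, 0, 0].item())  # 0.25
```

Note that dividing an integer tensor with `/` performs true division, so the result is a floating-point tensor.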

[ ]:
def scale(item: dict):
    item['image'] = item['image'] / 10000
    return item
[ ]:
train_imgs = RasterDataset(
    paths=(root / 'tra_scene').as_posix(), crs='epsg:3395', res=10, transforms=scale
)
train_msks = RasterDataset(
    paths=(root / 'tra_truth').as_posix(), crs='epsg:3395', res=10
)

valid_imgs = RasterDataset(
    paths=(root / 'val_scene').as_posix(), crs='epsg:3395', res=10, transforms=scale
)
valid_msks = RasterDataset(
    paths=(root / 'val_truth').as_posix(), crs='epsg:3395', res=10
)

# IMPORTANT: flag the mask datasets as non-image so their samples are returned under the 'mask' key
train_msks.is_image = False
valid_msks.is_image = False

train_dset = train_imgs & train_msks
valid_dset = valid_imgs & valid_msks

# Create the samplers

train_sampler = RandomGeoSampler(train_imgs, size=512, length=130, units=Units.PIXELS)
valid_sampler = RandomGeoSampler(valid_imgs, size=512, length=64, units=Units.PIXELS)

Note that we set the CRS (Coordinate Reference System) to EPSG:3395. TorchGeo requires all images to be loaded in the same CRS. However, the patches in the dataset are in different UTM projections, and by default TorchGeo uses the first CRS it finds. We therefore have to specify a CRS that can cope with these different regions around the globe. To minimize the deformations caused by the large differences in latitude between the patches (a topic that deserves a story of its own), I selected World Mercator as the main CRS for the project. The figure below shows the world projected in the World Mercator CRS.
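To put the deformation argument into numbers, the sketch below evaluates ellipsoidal Mercator (the projection behind EPSG:3395) with the standard WGS 84 constants. Its linear scale factor k grows with latitude, and because res=10 is expressed in projected meters, one pixel covers roughly 10 / k meters on the ground. The helper names mercator_y and scale_factor are hypothetical; for real reprojections, rely on a library such as pyproj or rasterio.

```python
import math

# WGS 84 ellipsoid constants, used by EPSG:3395 (World Mercator)
A = 6378137.0  # semi-major axis (m)
F = 1 / 298.257223563  # flattening
E2 = F * (2 - F)  # first eccentricity squared
E = math.sqrt(E2)


def mercator_y(lat_deg: float) -> float:
    """Northing (m) of a latitude in ellipsoidal Mercator."""
    phi = math.radians(lat_deg)
    s = E * math.sin(phi)
    return A * math.log(math.tan(math.pi / 4 + phi / 2) * ((1 - s) / (1 + s)) ** (E / 2))


def scale_factor(lat_deg: float) -> float:
    """Linear scale factor k of ellipsoidal Mercator at a given latitude."""
    phi = math.radians(lat_deg)
    return math.sqrt(1 - E2 * math.sin(phi) ** 2) / math.cos(phi)


for lat in (0, 30, 60):
    k = scale_factor(lat)
    # A 10-unit pixel in EPSG:3395 covers about 10 / k meters on the ground
    print(
        f'lat={lat:>2}  y={mercator_y(lat) / 1000:9.1f} km  k={k:.3f}  '
        f'ground size of a 10 m pixel = {10 / k:.2f} m'
    )
```

At 60° latitude, k is close to 2, so a "10 m" pixel in this CRS covers only about 5 m on the ground.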

Figure: the world projected in the World Mercator CRS (EPSG:3395).

Understanding the sampler

To create training patches that can be fed into a neural network, we need to select samples of a fixed size. TorchGeo has many samplers, but here we will use the RandomGeoSampler class. The sampler selects random fixed-size bounding boxes that lie within the original images. These bounding boxes are then used to query the corresponding portion of the image from the dataset. Here is an example using the previously created samplers.
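The core idea of the sampler can be sketched as: pick a random window of size x size pixels whose bounds stay inside a scene. BBox and random_bbox are hypothetical names, and the real RandomGeoSampler also chooses which file to sample from (among other details), so treat this only as an illustration.

```python
import random
from typing import NamedTuple


class BBox(NamedTuple):
    """Hypothetical axis-aligned bounding box in CRS units."""

    minx: float
    maxx: float
    miny: float
    maxy: float


def random_bbox(bounds: BBox, size_px: int, res: float, rng: random.Random) -> BBox:
    """Pick a random size_px x size_px window (in pixels) fully inside bounds."""
    width = size_px * res  # window size in CRS units
    if bounds.maxx - bounds.minx < width or bounds.maxy - bounds.miny < width:
        raise ValueError('bounds are smaller than the requested window')
    minx = rng.uniform(bounds.minx, bounds.maxx - width)
    miny = rng.uniform(bounds.miny, bounds.maxy - width)
    return BBox(minx, minx + width, miny, miny + width)


# A made-up 1000 x 800 pixel scene at 10 m resolution
scene = BBox(0.0, 10_000.0, 0.0, 8_000.0)
print(random_bbox(scene, size_px=512, res=10.0, rng=random.Random(0)))
```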

[ ]:
bbox = next(iter(train_sampler))
bbox
[ ]:
sample = train_dset[bbox]
sample.keys()
[ ]:
sample['image'].shape, sample['mask'].shape

Notice that the image and mask patches now have the same spatial size (…, 512 x 512).

Creating Dataloaders

Creating a DataLoader in TorchGeo is very straightforward, just as it is in PyTorch (we are actually using the same class). Note below that we also use the samplers defined earlier. Additionally, we pass the dataset the DataLoader will pull data from, the batch_size (number of samples in each batch), and a collate function that specifies how to “concatenate” multiple samples into a single batch.

Finally, we can iterate through the dataloader to grab batches from it. To test it, we will get the first batch.
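To make the role of the collate function concrete, here is a simplified stand-in: stack_dicts (a hypothetical helper) stacks tensor values along a new batch dimension and keeps everything else as a list. TorchGeo's stack_samples follows a similar idea, but use the real function in practice.

```python
import torch


def stack_dicts(samples: list[dict]) -> dict:
    """Merge a list of sample dicts into one batch dict (hypothetical helper)."""
    batch = {}
    for key in samples[0]:
        values = [s[key] for s in samples]
        if all(isinstance(v, torch.Tensor) for v in values):
            batch[key] = torch.stack(values)  # adds a leading batch dimension
        else:
            batch[key] = values  # non-tensor values (e.g. a CRS string) stay as a list
    return batch


samples = [
    {'image': torch.rand(6, 8, 8), 'mask': torch.zeros(1, 8, 8), 'crs': 'EPSG:3395'}
    for _ in range(4)
]
stacked = stack_dicts(samples)
print(stacked['image'].shape, stacked['mask'].shape)  # torch.Size([4, 6, 8, 8]) torch.Size([4, 1, 8, 8])
```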

[ ]:
# Adjust the batch size according to your GPU memory
train_dataloader = DataLoader(
    train_dset, sampler=train_sampler, batch_size=4, collate_fn=stack_samples
)
valid_dataloader = DataLoader(
    valid_dset, sampler=valid_sampler, batch_size=4, collate_fn=stack_samples
)

train_batch = next(iter(train_dataloader))
valid_batch = next(iter(valid_dataloader))
train_batch.keys(), valid_batch.keys()

Batch Visualization

Now that we can draw batches from our datasets, let’s create a function to display the batches.

The plot_batch function automatically checks the number of items in the batch and whether masks are present, and arranges the output grid accordingly.
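The grid arithmetic used by plot_batch can be isolated as follows. grid_layout is a hypothetical helper written for illustration: it returns the number of rows and the flat panel indices used for images and masks, mirroring the even/odd split in plot_batch.

```python
import math


def grid_layout(n_samples: int, has_masks: bool, cols: int = 4):
    """Return (rows, image_panels, mask_panels) for a flattened grid of axes."""
    # With masks, every sample needs two panels (image + mask)
    n = 2 * n_samples if has_masks else n_samples
    # Ceiling division: same as n // cols + (1 if n % cols != 0 else 0)
    rows = math.ceil(n / cols)
    panels = list(range(n))
    if has_masks:
        return rows, panels[::2], panels[1::2]  # images on even, masks on odd panels
    return rows, panels, []


print(grid_layout(4, has_masks=True))  # (2, [0, 2, 4, 6], [1, 3, 5, 7])
print(grid_layout(5, has_masks=False))  # (2, [0, 1, 2, 3, 4], [])
```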

[ ]:
def plot_imgs(
    images: Iterable, axs: Iterable, chnls: list[int] = [2, 1, 0], bright: float = 3.0
):
    for img, ax in zip(images, axs):
        arr = torch.clamp(bright * img, min=0, max=1).numpy()
        rgb = arr.transpose(1, 2, 0)[:, :, chnls]
        ax.imshow(rgb)
        ax.axis('off')


def plot_msks(masks: Iterable, axs: Iterable):
    for mask, ax in zip(masks, axs):
        ax.imshow(mask.squeeze().numpy(), cmap='Blues')
        ax.axis('off')


def plot_batch(
    batch: dict,
    bright: float = 3.0,
    cols: int = 4,
    width: int = 5,
    chnls: list[int] = [2, 1, 0],
):
    # Get the samples and the number of items in the batch
    samples = unbind_samples(batch.copy())

    # if batch contains images and masks, the number of images will be doubled
    n = 2 * len(samples) if ('image' in batch) and ('mask' in batch) else len(samples)

    # calculate the number of rows in the grid
    rows = n // cols + (1 if n % cols != 0 else 0)

    # create a grid
    _, axs = plt.subplots(rows, cols, figsize=(cols * width, rows * width))

    if ('image' in batch) and ('mask' in batch):
        # plot the images on the even-indexed axes
        plot_imgs(
            images=map(lambda x: x['image'], samples),
            axs=axs.reshape(-1)[::2],
            chnls=chnls,
            bright=bright,
        )

        # plot the masks on the odd-indexed axes
        plot_msks(masks=map(lambda x: x['mask'], samples), axs=axs.reshape(-1)[1::2])

    else:
        if 'image' in batch:
            plot_imgs(
                images=map(lambda x: x['image'], samples),
                axs=axs.reshape(-1),
                chnls=chnls,
                bright=bright,
            )

        elif 'mask' in batch:
            plot_msks(masks=map(lambda x: x['mask'], samples), axs=axs.reshape(-1))
[ ]:
plot_batch(train_batch)

Data Standardization and Spectral Indices

Machine learning methods (deep learning included) usually benefit from feature scaling, which here means rescaling each channel to zero mean and unit standard deviation with the following formula: \(X'=\frac{X-\mu}{\sigma}\), where \(\mu\) is the mean and \(\sigma\) the standard deviation.
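Here is the formula applied per channel in plain PyTorch. The standardize helper and the random mock batch are hypothetical; in the notebook, this job is done by K.Normalize with the statistics computed below.

```python
import torch


def standardize(x: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    """Apply X' = (X - mean) / std channel-wise to a (N, C, H, W) batch."""
    # Reshape the per-channel vectors to (1, C, 1, 1) so they broadcast over N, H and W
    return (x - mean.view(1, -1, 1, 1)) / std.view(1, -1, 1, 1)


torch.manual_seed(0)
sample_batch = torch.rand(4, 6, 32, 32) * 0.3  # mock reflectance-like values
batch_mean = sample_batch.mean(dim=(0, 2, 3))
batch_std = sample_batch.std(dim=(0, 2, 3))
z = standardize(sample_batch, batch_mean, batch_std)
print(z.mean(dim=(0, 2, 3)), z.std(dim=(0, 2, 3)))  # close to 0 and 1 in every channel
```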

To do that, we first need to find the mean and standard deviation of each of the 6 channels in the dataset.

Let’s define a function to calculate these statistics and store the results in the variables mean and std. We will use the previously installed rasterio package to open each image and simply average the per-image statistics for each channel. For the mean, this is exact only when all images have the same number of pixels; for the standard deviation, it is an approximation. For a more precise calculation, please refer to: http://notmatthancock.github.io/2017/03/23/simple-batch-stat-updates.html.
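As a simpler exact alternative to averaging per-image statistics (not the incremental method from the linked article), the sketch below accumulates per-channel sums, sums of squares and pixel counts, then derives the mean and the population standard deviation. exact_channel_stats is a hypothetical helper; in the notebook, you would feed it the arrays read with rasterio and divided by 10000. The sum-of-squares formula can lose precision on very large datasets, which is why the sketch accumulates in float64.

```python
from collections.abc import Iterable

import numpy as np


def exact_channel_stats(images: Iterable[np.ndarray]) -> tuple[np.ndarray, np.ndarray]:
    """Exact per-channel mean and (population) std over (C, H, W) arrays of any size."""
    total = total_sq = 0.0
    count = 0
    for img in images:
        # Flatten the spatial dims and accumulate in float64 to limit rounding errors
        flat = img.reshape(img.shape[0], -1).astype(np.float64)
        total = total + flat.sum(axis=1)
        total_sq = total_sq + (flat**2).sum(axis=1)
        count += flat.shape[1]
    mean = total / count
    # Var[X] = E[X^2] - E[X]^2; clip tiny negative values caused by rounding
    var = np.maximum(total_sq / count - mean**2, 0.0)
    return mean, np.sqrt(var)


rng = np.random.default_rng(0)
# Two mock 6-band images of different sizes (random values, for illustration only)
imgs = [rng.random((6, 20, 30)), rng.random((6, 50, 10))]
exact_mean, exact_std = exact_channel_stats(imgs)
print(exact_mean.round(3), exact_std.round(3))
```

Unlike the per-image average, every pixel carries the same weight here, regardless of the size of the image it comes from.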

[ ]:
def calc_statistics(dset: RasterDataset):
    """
    Calculate the statistics (mean and std) for the entire dataset
    Warning: This is an approximation. The correct value should take into account the
    mean for the whole dataset for computing individual stds.
    For correctness I suggest checking: http://notmatthancock.github.io/2017/03/23/simple-batch-stat-updates.html
    """

    # To avoid loading the entire dataset in memory, we will loop through each img
    # The filenames will be retrieved from the dataset's GeoDataFrame index
    files = dset.index.filepath

    # Resetting statistics
    accum_mean = 0
    accum_std = 0

    for file in files:
        # Open each file with a context manager so the handle is closed afterwards
        with rio.open(file) as src:
            img = src.read() / 10000  # type: ignore
        accum_mean += img.reshape((img.shape[0], -1)).mean(axis=1)
        accum_std += img.reshape((img.shape[0], -1)).std(axis=1)

    # at the end, we shall have 2 vectors with length n=chnls
    # we will average them considering the number of images
    return accum_mean / len(files), accum_std / len(files)
[ ]:
# Calculate the statistics (Mean and std) for the dataset
mean, std = calc_statistics(train_imgs)

# Note that we will create the spectral indices from the raw (non-normalized) data, so the tensors
# reaching the normalization step will have 3 extra channels (the indices) that should not be normalized.
# To handle this, we append 0s to the mean vector and 1s to the std vector for the index channels
mean = np.concatenate([mean, [0, 0, 0]])
std = np.concatenate([std, [1, 1, 1]])

norm = K.Normalize(mean=mean, std=std)

tfms = torch.nn.Sequential(
    indices.AppendNDWI(index_green=1, index_nir=3),
    indices.AppendNDWI(index_green=1, index_nir=5),
    indices.AppendNDVI(index_nir=3, index_red=2),
    norm,
)
[ ]:
transformed_img = tfms(train_batch['image'])
print(transformed_img.shape)

Note that our transformed batch now has 9 channels instead of 6.
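The three extra channels come from TorchGeo's AppendNDWI and AppendNDVI transforms, which append normalized-difference indices. The sketch below reimplements the idea in plain PyTorch; append_normalized_difference is a hypothetical helper, its eps guard is our own addition rather than TorchGeo's exact implementation, and the band names in the comments assume the six bands are ordered Blue, Green, Red, NIR, SWIR1, SWIR2 (which is what the index arguments in tfms suggest).

```python
import torch


def append_normalized_difference(
    x: torch.Tensor, index_a: int, index_b: int, eps: float = 1e-10
) -> torch.Tensor:
    """Append (band_a - band_b) / (band_a + band_b) as a new channel of a (N, C, H, W) batch."""
    a = x[:, index_a : index_a + 1]  # slicing keeps the channel dim: (N, 1, H, W)
    b = x[:, index_b : index_b + 1]
    return torch.cat([x, (a - b) / (a + b + eps)], dim=1)


mock_batch = torch.rand(4, 6, 16, 16)  # random stand-in for 6 reflectance bands
out = append_normalized_difference(mock_batch, index_a=1, index_b=3)  # NDWI: (Green - NIR) / (Green + NIR)
out = append_normalized_difference(out, index_a=1, index_b=5)  # Green vs band 5 (assumed SWIR)
out = append_normalized_difference(out, index_a=3, index_b=2)  # NDVI: (NIR - Red) / (NIR + Red)
print(out.shape)  # torch.Size([4, 9, 16, 16])
```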

Important: because we padded the mean and std vectors with 0s and 1s, the normalization changes only the 6 original bands and leaves the appended indices untouched. The padding also keeps the mean and std vectors the same length as the 9-channel batch, which avoids errors due to mismatched shapes.
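The effect of padding the statistics with 0s and 1s can be checked on a toy tensor: after normalization, the index channels come out unchanged. The per-band statistics below are random stand-ins, not the dataset's values.

```python
import torch

n_bands, n_indices = 6, 3
torch.manual_seed(0)
band_mean = torch.rand(n_bands)  # made-up per-band statistics
band_std = torch.rand(n_bands) + 0.5
# Pad with mean 0 and std 1 so that (x - 0) / 1 leaves the index channels untouched
mean_padded = torch.cat([band_mean, torch.zeros(n_indices)])
std_padded = torch.cat([band_std, torch.ones(n_indices)])

x = torch.rand(2, n_bands + n_indices, 8, 8)
x_norm = (x - mean_padded.view(1, -1, 1, 1)) / std_padded.view(1, -1, 1, 1)
print(torch.equal(x_norm[:, n_bands:], x[:, n_bands:]))  # True
```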

Segmentation Model

For the semantic segmentation model, we are going to use a predefined architecture available in torchvision. The list of models (https://pytorch.org/vision/stable/models.html#semantic-segmentation) shows three architectures for semantic segmentation, one of which (LRASPP) is intended for mobile applications. In this tutorial, we will use the DeepLabV3 model.

Here, we will create a DeepLabV3 model for 2 classes. In this case, I will skip the pretrained weights, as the weights represent another domain (not water segmentation from multispectral imagery).

[ ]:
from torchvision.models.segmentation import deeplabv3_resnet50

model = deeplabv3_resnet50(weights=None, num_classes=2)
model

The first thing to pay attention to in the model architecture is the number of channels expected by the first convolution (Conv2d), which is 3 because the model is designed for RGB images. This first convolution turns the 3 input channels into 64 channels at a lower resolution, and so on. Since our batches now have 9 channels, we need to adapt this first layer by replacing it with a new convolution that accepts 9 channels, using the following commands. Finally, we check that a mock batch can pass through the model and that the output has 2 channels (water / no_water), as desired.
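Rather than hard-coding the kernel size, stride and padding, the replacement layer can also be derived from the existing one. with_in_channels below is a hypothetical helper that builds a freshly initialized Conv2d mirroring another layer's hyperparameters except for the number of input channels. With the torchvision model above, you could then write backbone.conv1 = with_in_channels(backbone.conv1, 9). Like the approach in this tutorial, it does not reuse any pretrained weights.

```python
import torch


def with_in_channels(conv: torch.nn.Conv2d, in_channels: int) -> torch.nn.Conv2d:
    """Return a fresh Conv2d that mirrors `conv` but accepts `in_channels` inputs."""
    return torch.nn.Conv2d(
        in_channels=in_channels,
        out_channels=conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        dilation=conv.dilation,
        groups=conv.groups,
        bias=conv.bias is not None,
        padding_mode=conv.padding_mode,
    )


# Stand-in for ResNet-50's first layer (3 -> 64 channels, 7x7 kernel, stride 2)
rgb_conv = torch.nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
new_conv = with_in_channels(rgb_conv, in_channels=9)
print(new_conv(torch.randn(1, 9, 64, 64)).shape)  # torch.Size([1, 64, 32, 32])
```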

[ ]:
backbone = model.get_submodule('backbone')

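# Replace the first convolution with one that accepts 9 input channels
# (6 bands + 3 indices); the other hyperparameters match the original layer
conv = torch.nn.Conv2d(
    in_channels=9,
    out_channels=64,
    kernel_size=(7, 7),
    stride=(2, 2),
    padding=(3, 3),
    bias=False,
)
backbone.register_module('conv1', conv)

# Sanity check: a mock batch of 3 samples should produce 2 output channels
pred = model(torch.randn(3, 9, 512, 512))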
pred['out'].shape

Training Loop

The training function receives the number of epochs, the dataloaders, the model, the loss function (to be optimized), the optimizer (which adjusts the model parameters in the correct direction), the accuracy functions (to assess the results), and the transformations to be applied to each batch.

[ ]:
# Check if GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device
[ ]:
def train_loop(
    epochs: int,
    train_dl: DataLoader,
    val_dl: DataLoader | None,
    model: torch.nn.Module,
    loss_fn: Callable,
    optimizer: torch.optim.Optimizer,
    acc_fns: list | None = None,
    batch_tfms: Callable | None = None,
):
    model = model.to(device)

    for epoch in range(epochs):
        # Training mode: BatchNorm uses batch statistics and Dropout is active
        model.train()
        accum_loss = 0
        for batch in train_dl:
            if batch_tfms is not None:
                X = batch_tfms(batch['image']).to(device)
            else:
                X = batch['image'].type(torch.float32).to(device)

            y = batch['mask'].type(torch.long).to(device)
            pred = model(X)['out']
            loss = loss_fn(pred, y)

            # BackProp
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # update the accumulated (average) loss
            accum_loss += float(loss) / len(train_dl)

        # Testing against the validation dataset
        if acc_fns is not None and val_dl is not None:
            # Evaluation mode: BatchNorm uses running statistics and Dropout is disabled
            model.eval()

            # reset the accuracy metrics
            acc = [0.0] * len(acc_fns)

            with torch.no_grad():
                for batch in val_dl:
                    if batch_tfms is not None:
                        X = batch_tfms(batch['image']).to(device)
                    else:
                        X = batch['image'].type(torch.float32).to(device)

                    y = batch['mask'].type(torch.long).to(device)

                    pred = model(X)['out']

                    for i, acc_fn in enumerate(acc_fns):
                        acc[i] = float(acc[i] + acc_fn(pred, y) / len(val_dl))

            # at the end of the epoch, print the errors, etc.
            print(
                f'Epoch {epoch}: Train Loss={accum_loss:.5f} - Accs={[round(a, 3) for a in acc]}'
            )
        else:
            print(f'Epoch {epoch}: Train Loss={accum_loss:.5f}')

Loss and Accuracy Functions

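For the loss function, cross-entropy is the usual choice, but PyTorch's implementation expects the target mask to have shape (N, d1, d2), whereas our masks have shape (N, 1, d1, d2). We therefore need to squeeze the second (channel) dimension manually. To assess the results, we also define two metrics: overall accuracy (oa) and intersection over union (iou) of the positive (water) class.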
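When squeezing the mask, it is safer to drop only dimension 1: with a batch of one, squeeze() without arguments would also remove the batch dimension, leaving a (H, W) target that does not match a (1, 2, H, W) prediction. The tensors below are random mock data.

```python
import torch
import torch.nn.functional as F

pred = torch.randn(1, 2, 8, 8)  # (N, classes, H, W) logits, batch of one
mask = torch.randint(0, 2, (1, 1, 8, 8))  # (N, 1, H, W), as returned by the dataloader

target = mask.squeeze(1)  # (N, H, W): drops only the channel dim
ce = F.cross_entropy(pred, target)
print(target.shape, mask.squeeze().shape)  # torch.Size([1, 8, 8]) torch.Size([8, 8])
```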

[ ]:
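def oa(pred, y):
    """Overall accuracy: fraction of pixels whose predicted class matches the mask."""
    flat_y = y.squeeze(1)
    flat_pred = pred.argmax(dim=1)
    acc = torch.count_nonzero(flat_y == flat_pred) / torch.numel(flat_y)
    return acc


def iou(pred, y):
    """Intersection over union (Jaccard index) of the positive (water) class."""
    flat_y = y.squeeze(1).cpu().numpy()
    flat_pred = pred.argmax(dim=1).detach().cpu().numpy()
    return jaccard_score(flat_y.reshape(-1), flat_pred.reshape(-1), zero_division=1.0)


def loss(p, t):
    """Cross-entropy between (N, C, H, W) logits and (N, 1, H, W) masks."""
    # squeeze(1) drops only the channel dim, so a batch of one keeps its batch dim
    return torch.nn.functional.cross_entropy(p, t.squeeze(1))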
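The iou function above moves predictions to the CPU for scikit-learn. If that transfer becomes a bottleneck, the same binary IoU can be computed directly on tensors, as in the sketch below. iou_torch is a hypothetical helper; it treats class 1 as the positive (water) class and, like zero_division=1.0, returns 1.0 when neither the prediction nor the mask contains that class.

```python
import torch


def iou_torch(pred: torch.Tensor, y: torch.Tensor, positive: int = 1, empty_value: float = 1.0) -> float:
    """IoU of the `positive` class for (N, C, H, W) logits and (N, 1, H, W) masks."""
    pred_pos = pred.argmax(dim=1) == positive  # (N, H, W) predicted water pixels
    true_pos = y.squeeze(1) == positive  # (N, H, W) reference water pixels
    inter = torch.logical_and(pred_pos, true_pos).sum()
    union = torch.logical_or(pred_pos, true_pos).sum()
    if union == 0:
        return empty_value  # neither prediction nor mask contains the class
    return (inter.float() / union.float()).item()


logits = torch.zeros(1, 2, 2, 2)
logits[0, 1, 0, :] = 1.0  # predict water on the top row only
mask = torch.tensor([[[[1, 0], [1, 0]]]])  # water in the left column
print(iou_torch(logits, mask))  # 1 shared pixel / 3 pixels in the union
```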

Training

Training the model is much faster on a CUDA GPU. In Colab, you can enable one by changing the runtime type and re-running the notebook. Without a GPU, the cell below trains for a single epoch on a reduced number of samples.

[ ]:
# adjust number of epochs depending on the device
if torch.cuda.is_available():
    num_epochs = 2
else:
    # if GPU is not available, just make 1 pass and limit the size of the datasets
    num_epochs = 1

    # by limiting the length of the sampler we limit the iterations in each epoch
    train_dataloader.sampler.length = 8
    valid_dataloader.sampler.length = 8

# train the model
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=0.01)
train_loop(
    num_epochs,
    train_dataloader,
    valid_dataloader,
    model,
    loss,
    optimizer,
    acc_fns=[oa, iou],
    batch_tfms=tfms,
)

Additional Reading

This tutorial is also available as a three-part Medium story: https://medium.com/towards-data-science/artificial-intelligence-for-geospatial-analysis-with-pytorchs-torchgeo-part-1-52d17e409f09
