Shortcuts

Source code for torchgeo.datasets.dl4gam

# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.

"""DL4GAMAlps Dataset."""

import pathlib
from collections.abc import Callable, Sequence
from typing import ClassVar

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from matplotlib.figure import Figure
from torch import Tensor

from .errors import DatasetNotFoundError, RGBBandsMissingError
from .geo import NonGeoDataset
from .utils import (
    Path,
    download_and_extract_archive,
    download_url,
    extract_archive,
    lazy_import,
)


class DL4GAMAlps(NonGeoDataset):
    r"""A Multi-modal Dataset for Glacier Mapping (Segmentation) in the European Alps.

    The dataset consists of Sentinel-2 images from 2015 (mainly), 2016 and 2017, and
    binary segmentation masks for glaciers, based on an inventory built by glaciology
    experts (`Paul et al. 2020 <https://doi.org/10.1594/PANGAEA.909133>`_).

    Given that glacier ice is not always visible in the images, due to seasonal snow,
    shadow/cloud cover and, most importantly, debris cover, the dataset also includes
    additional features that can help in the segmentation task.

    Dataset features:

    * Sentinel-2 images (all bands, including cloud and shadow masks which can be used
      for loss masking)
    * glacier mask (0: no glacier, 1: glacier)
    * debris mask (0: no debris, 1: debris) based on a mix of three publications
      (`Scherler et al. 2018 <https://doi.org/10.5880/GFZ.3.3.2018.005>`_,
      `Herreid & Pellicciotti 2020 <https://doi.org/10.5281/zenodo.3866466>`_,
      `Linsbauer et al. 2021 <https://doi.glamos.ch/data/inventory/inventory_sgi2016_r2020.html>`_)
    * DEM (Copernicus GLO-30) + five derived features (using
      `xDEM <https://github.com/GlacioHack/xdem>`_): slope, aspect, terrain ruggedness
      index, planform and profile curvatures
    * dh/dt (surface elevation change) map over 2010-2015
      (`Hugonnet et al. 2021 <https://doi.org/10.6096/13>`_)
    * v (surface velocity) map over 2015
      (`ITS_LIVE <https://its-live.jpl.nasa.gov/>`_)

    Other specifications:

    * temporal coverage: one acquisition per glacier, from either 2015 (mainly), 2016,
      or 2017
    * spatial coverage: only glaciers larger than 0.1 km\ :sup:`2`\ are considered
      (n=1593, after manual QC), totalling ~1685 km\ :sup:`2`\ which represents ~93%
      of the total inventory area for this region
    * 2251 patches sampled with overlap from the 1593 glaciers; or 11440 for the
      `large` version, obtained with an increased sampling overlap
    * the dataset download size is 5.8 GB (11 GB when unarchived); or 29.5 GB (52 GB
      when unarchived) for the `large` version
    * the dataset is provided at 10m GSD (after bilinearly resampling some of the
      Sentinel-2 bands and the additional features which come at a lower resolution)
    * the dataset provides fixed training, validation, and test geographical splits
      (70-10-20, by glacier area)
    * five different splits are provided, according to a five-fold cross-validation
      scheme
    * all the features/masks are stacked and provided as NetCDF files (one or more per
      glacier), structured as
      `data/{glacier_id}/{glacier_id}_{patch_number}_{center_x}_{center_y}.nc`
    * data is projected and geocoded in local UTM zones

    For more details check also: https://huggingface.co/datasets/dcodrut/dl4gam_alps

    If you use this dataset in your research, please cite the following paper:

    * https://doi.org/10.22541/essoar.173557607.70204641/v1

    .. note::

       This dataset requires the following additional libraries to be installed:

       * `xarray <https://pypi.org/project/xarray/>`_
       * `netcdf4 <https://pypi.org/project/netCDF4/>`_ or
         `h5netcdf <https://pypi.org/project/h5netcdf/>`_

    .. versionadded:: 0.7
    """

    # Base URL, pinned to a specific dataset revision on the Hugging Face hub.
    url = 'https://huggingface.co/datasets/dcodrut/dl4gam_alps/resolve/7d20ca8a2b30c5518e086ffaa5ce37e6a66c42c1/data'

    # Download location and MD5 checksum of every remote file used by the dataset.
    download_metadata: ClassVar[dict[str, dict[str, str]]] = {
        'dataset_small': {
            'url': f'{url}/patches/inv_r_128_s_128.tar.gz',
            'checksum': '3e69c47c6ff5106cd4ffaa6bb2caaaef',
        },
        'dataset_large': {
            'url': f'{url}/patches/inv_r_128_s_64.tar.gz',
            'checksum': '06e85a6a9e3dc6b3cdb07f928e832bc8',
        },
        'splits_csv': {
            'url': f'{url}/map_all_splits_all_folds.csv',
            'checksum': '862355c5c3482271dd171d31c70551b3',
        },
    }

    rgb_bands = ('B4', 'B3', 'B2')
    all_bands = (
        'B1',
        'B2',
        'B3',
        'B4',
        'B5',
        'B6',
        'B7',
        'B8',
        'B8A',
        'B9',
        'B10',
        'B11',
        'B12',
    )
    rgb_nir_swir_bands = ('B4', 'B3', 'B2', 'B8', 'B11')  # the subset used in the paper
    valid_extra_features = (
        'dem',  # Digital Elevation Model
        'slope',
        'aspect',
        'planform_curvature',
        'profile_curvature',
        'terrain_ruggedness_index',  # DEM-based features
        'dhdt',  # surface elevation change
        'v',  # surface velocity
    )

    valid_splits = ('train', 'val', 'test')
    valid_versions = ('small', 'large')
    valid_cv_iters = (1, 2, 3, 4, 5)

    def __init__(
        self,
        root: Path = 'data',
        split: str = 'train',
        cv_iter: int = 1,
        version: str = 'small',
        bands: Sequence[str] = rgb_nir_swir_bands,
        extra_features: Sequence[str] | None = None,
        transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
        download: bool = False,
        checksum: bool = False,
    ) -> None:
        """Initialize the dataset.

        Args:
            root: root directory where dataset can be found
            split: one of "train", "val", or "test"
            cv_iter: one of 1, 2, 3, 4, 5 (for the five-fold geographical
                cross-validation scheme)
            version: one of "small" or "large" (controls the sampling overlap)
            bands: the Sentinel-2 bands to use as input (default: RGB + NIR + SWIR)
            extra_features: additional features to include (default: None; see the
                class attribute ``valid_extra_features`` for the available ones)
            transforms: a function/transform that takes input sample and its target
                as entry and returns a transformed version
            download: if True, download dataset and store it in the root directory
            checksum: if True, check the MD5 of the downloaded files (may be slow)

        Raises:
            AssertionError: if any parameters are invalid.
            DatasetNotFoundError: if dataset is not found and *download* is False.
            DependencyNotFoundError: if xarray is not installed.
        """
        # Fail early if the optional dependency is missing.
        lazy_import('xarray')

        self.root = pathlib.Path(root)
        self.split = split
        self.cv_iter = cv_iter
        self.version = version
        self.bands = bands
        self.extra_features = extra_features
        self.transforms = transforms
        self.download = download
        self.checksum = checksum

        # sanity checks (AssertionError is part of the documented contract)
        assert split in self.valid_splits, f'Split {split} not in: {self.valid_splits}'
        assert cv_iter in self.valid_cv_iters, (
            f'Cross-validation iteration {cv_iter} not in: {self.valid_cv_iters}'
        )
        assert version in self.valid_versions, (
            f'Version {version} not in: {self.valid_versions}'
        )
        for band in bands:
            assert band in self.all_bands, f'Band {band} not in: {self.all_bands}'
        if extra_features:
            for feature in extra_features:
                assert feature in self.valid_extra_features, (
                    f'Feature {feature} not in: {self.valid_extra_features}'
                )

        # set the local file paths
        label = f'dataset_{version}'
        self.fp_archive = self.root / f'{label}.tar.gz'
        self.dir_patches = self.root / label
        self.fp_splits_csv = self.root / 'splits.csv'

        # get the corresponding urls and checksums
        self.url_dataset = self.download_metadata[label]['url']
        self.md5_dataset = self.download_metadata[label]['checksum']
        self.url_csv_splits = self.download_metadata['splits_csv']['url']
        self.md5_csv_splits = self.download_metadata['splits_csv']['checksum']

        self._verify()
        self._prepare_files()

    def __len__(self) -> int:
        """The length of the dataset.

        Returns:
            the number of patches in the dataset
        """
        return len(self.fp_patches)

    def __getitem__(self, index: int) -> dict[str, Tensor]:
        """Load the NetCDF file for the given index and return the sample as a dict.

        Args:
            index: index of the sample to return

        Returns:
            dict: a dictionary containing the sample with the following:

                * the Sentinel-2 image (selected bands)
                * the glacier mask (binary mask with all the glaciers in the current
                  patch)
                * the debris mask
                * the cloud and shadow mask
                * the additional features (DEM, derived features, etc.) if required
        """
        xr = lazy_import('xarray')

        # Open the file in a context manager so the underlying handle is released
        # once the sample is built; everything below materializes NumPy arrays
        # (``.values`` / ``astype``) before the dataset is closed.
        with xr.open_dataset(
            self.fp_patches[index], decode_coords='all', mask_and_scale=True
        ) as nc:
            # extract the S2 image and masks from the netcdf file; the
            # ``long_name`` attribute is used as the ordered list of band names
            all_band_names = nc.band_data.long_name
            idx_img = [all_band_names.index(b) for b in self.bands]
            image = nc.band_data.isel(band=idx_img).values.astype(np.float32)

            # True wherever the 'CLOUDLESS_MASK' band is not equal to 1
            id_cloud_mask = all_band_names.index('CLOUDLESS_MASK')
            mask_clouds_and_shadows = ~(
                nc.band_data.isel(band=id_cloud_mask).values == 1
            )

            sample = {
                'image': torch.from_numpy(image),
                # a pixel belongs to a glacier if it carries a (non-NaN) glacier ID
                'mask_glacier': torch.from_numpy(~np.isnan(nc.mask_all_g_id.values)),
                'mask_debris': torch.from_numpy(nc.mask_debris.values == 1),
                'mask_clouds_and_shadows': torch.from_numpy(mask_clouds_and_shadows),
            }

            # extract the additional features if needed
            if self.extra_features:
                for feature in self.extra_features:
                    assert feature in nc, (
                        f'Feature {feature} not found in the netcdf file'
                    )
                    vals = nc[feature].values.astype(np.float32)

                    # impute the missing values with the mean
                    # or zero (for dh/dt and surface velocity)
                    v_fill = 0.0 if feature in ('dhdt', 'v') else np.nanmean(vals)
                    vals[np.isnan(vals)] = v_fill
                    sample[feature] = torch.from_numpy(vals)

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample

    def _verify(self) -> None:
        """Verify the integrity of the dataset.

        Makes sure both the extracted patches and the splits CSV are available,
        extracting a local archive and/or downloading the missing files if needed.

        Raises:
            DatasetNotFoundError: if required files are missing and *download* is
                False.
        """
        if self.dir_patches.exists() and self.fp_splits_csv.exists():
            return

        # check if the archive exists
        if self.fp_archive.exists():
            extract_archive(self.fp_archive, self.dir_patches)

            # The archive only contains the patches; the splits CSV is a separate
            # file and must be present as well, otherwise `_prepare_files` fails.
            if self.fp_splits_csv.exists():
                return
            if not self.download:
                raise DatasetNotFoundError(self)
            self._download_splits_csv()
            return

        if not self.download:
            raise DatasetNotFoundError(self)

        self._download()

    def _download(self) -> None:
        """Download the patches and the csv with the splits."""
        # download and extract the archive
        download_and_extract_archive(
            url=self.url_dataset,
            download_root=self.root,
            extract_root=self.dir_patches,
            filename=self.fp_archive.name,
            md5=self.md5_dataset if self.checksum else None,
        )

        self._download_splits_csv()

    def _download_splits_csv(self) -> None:
        """Download the csv with the cross-validation splits into the root dir."""
        download_url(
            url=self.url_csv_splits,
            root=self.root,
            filename='splits.csv',
            md5=self.md5_csv_splits if self.checksum else None,
        )

    def _prepare_files(self) -> None:
        """Prepare the files for the dataset.

        Reads the splits CSV into ``self.df_splits`` and stores in
        ``self.fp_patches`` the sorted paths of the patches whose parent directory
        name is a glacier ID of the current split and cross-validation iteration.
        """
        # get the glacier IDs of the current split of the cross-validation
        self.df_splits = pd.read_csv(self.fp_splits_csv)
        # the validation fold is named 'fold_valid' in the csv
        fold_name = f'fold_{self.split if self.split != "val" else "valid"}'
        idx = self.df_splits[f'split_{self.cv_iter}'] == fold_name
        # a set gives constant-time membership tests while filtering the patches
        glacier_ids = set(self.df_splits.loc[idx, 'entry_id'])

        # keep only the patches corresponding to the current split, sorted by path
        self.fp_patches = sorted(
            fp for fp in self.dir_patches.rglob('*.nc') if fp.parent.name in glacier_ids
        )

    def plot(
        self,
        sample: dict[str, Tensor],
        show_titles: bool = True,
        suptitle: str | None = None,
        clip_extrema: bool = True,
    ) -> Figure:
        """Plot a sample from the dataset.

        Args:
            sample: a sample returned by :meth:`DL4GAMAlps.__getitem__`
            show_titles: flag indicating whether to show titles above each panel
            suptitle: optional string to use as a suptitle
            clip_extrema: flag indicating whether to clip the lowest/highest 2.5% of
                the values for contrast enhancement

        Returns:
            a matplotlib Figure with the rendered sample

        Raises:
            RGBBandsMissingError: If *bands* does not include all RGB bands.
        """
        # we expect the RGB bands to be present
        if not set(self.rgb_bands).issubset(self.bands):
            raise RGBBandsMissingError()
        nir_and_swir_present = {'B8', 'B11'}.issubset(self.bands)

        # prepare the RGB image (channels last) and the masks
        idx_rgb = [self.bands.index(b) for b in self.rgb_bands]
        rgb_img = sample['image'][idx_rgb].permute(1, 2, 0)
        images: dict[str, Tensor] = {
            'RGB Image': rgb_img,
            'Glacier Mask': sample['mask_glacier'],
            'Debris Mask': sample['mask_debris'],
            'Clouds and Shadows Mask': sample['mask_clouds_and_shadows'],
        }

        # add the SWIR-NIR-R image if the bands are present
        if nir_and_swir_present:
            idx_swir_nir_r = [self.bands.index(b) for b in ['B11', 'B8', 'B4']]
            swir_nir_r_img = sample['image'][idx_swir_nir_r].permute(1, 2, 0)
            images['SWIR-NIR-R Image'] = swir_nir_r_img

        # add the prediction and the extra features if present in the sample
        for extra_v, title in (
            ('prediction', 'Prediction'),
            ('dem', 'DEM'),
            ('slope', 'Slope'),
            ('aspect', 'Aspect'),
            ('planform_curvature', 'Planform Curvature'),
            ('profile_curvature', 'Profile Curvature'),
            ('terrain_ruggedness_index', 'Terrain Ruggedness Index'),
            ('dhdt', 'dh/dt'),
            ('v', 'Surface Velocity'),
        ):
            if extra_v in sample:
                images[title] = sample[extra_v]

        cmaps: dict[str, str | None] = {
            'RGB Image': None,
            'SWIR-NIR-R Image': None,
            'Glacier Mask': 'gray',
            'Prediction': 'gray',
            'Debris Mask': 'gray',
            'Clouds and Shadows Mask': 'gray',
            'DEM': 'terrain',
            'Slope': 'magma',
            'Aspect': 'jet',
            'Planform Curvature': 'magma',
            'Profile Curvature': 'magma',
            'Terrain Ruggedness Index': 'magma',
            'dh/dt': 'seismic_r',
            'Surface Velocity': 'magma',
        }

        # panels that are never clipped for contrast enhancement
        unclipped_panels = (
            'Glacier Mask',
            'Prediction',
            'Debris Mask',
            'Clouds and Shadows Mask',
        )
        composite_panels = ('RGB Image', 'SWIR-NIR-R Image')

        # build the figure
        n_imgs = len(images)
        ncols = 4 if n_imgs <= 8 else 5
        nrows = int(np.ceil(n_imgs / ncols))
        fig, axs = plt.subplots(nrows, ncols, figsize=(5 * ncols, 5 * nrows))

        for ax, k in zip(axs.flat, images):
            img = images[k].numpy()
            cmap = cmaps[k]

            # clip the extrema (lowest and highest 2.5% of the values) if needed
            if clip_extrema and k not in unclipped_panels:
                q_lim_clip = 0.025
                img = np.clip(
                    img, np.quantile(img, q_lim_clip), np.quantile(img, 1 - q_lim_clip)
                )
            vmin, vmax = np.min(img), np.max(img)

            # scale the composite images to [0,1]
            if k in composite_panels:
                value_range = vmax - vmin
                if value_range > 0:
                    img = (img - vmin) / value_range
                else:
                    # constant image: avoid a division by zero (NaN panel)
                    img = np.zeros_like(img)

            if k == 'dh/dt':
                # diverging colormap for the dh/dt, make it symmetric
                max_abs = max(abs(vmin), abs(vmax))
                vmin, vmax = -max_abs, max_abs

            ax.imshow(img, cmap=cmap, interpolation='none', vmin=vmin, vmax=vmax)
            if show_titles:
                ax.set_title(k)

        # disable the axes for all plots, including the empty plots
        for ax in axs.flat:
            ax.axis('off')

        if suptitle:
            fig.suptitle(suptitle)

        return fig

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources