Shortcuts

Source code for torchgeo.datasets.caffe

# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.

"""CaFFe dataset."""

import glob
import os
import textwrap
from collections.abc import Callable
from typing import ClassVar

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.colors import ListedColormap
from matplotlib.figure import Figure
from PIL import Image
from torch import Tensor

from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
from .utils import Path, download_and_extract_archive, extract_archive


class CaFFe(NonGeoDataset):
    """CaFFe (CAlving Fronts and where to Find thEm) dataset.

    The `CaFFe <https://doi.pangaea.de/10.1594/PANGAEA.940950>`__ dataset is a
    semantic segmentation dataset of marine-terminating glaciers.

    Dataset features:

    * 13,090 train, 2,241 validation, and 3,761 test images
    * varying spatial resolution of 6-20m
    * paired binary calving front segmentation masks
    * paired multi-class land cover segmentation masks

    Dataset format:

    * images are single-channel pngs with dimension 512x512
    * segmentation masks are single-channel pngs

    Dataset classes:

    0. N/A
    1. rock
    2. glacier
    3. ocean/ice melange

    If you use this dataset in your research, please cite the following paper:

    * https://essd.copernicus.org/articles/14/4287/2022/

    .. versionadded:: 0.7
    """

    valid_splits = ('train', 'val', 'test')

    zipfilename = 'caffe.zip'

    # Name of the top-level directory expected under ``root``.
    data_dir = 'caffe'

    image_dir = 'sar_images'

    # Index 0 holds the front masks, index 1 the zone masks.
    mask_dirs = ('fronts', 'zones')

    url = 'https://huggingface.co/datasets/torchgeo/caffe/resolve/cc96e8418981ce0f03afc9beace6422fdd7142c4/caffe.zip'

    md5 = '9a92fd6f05af74fbc41602595a55df0d'

    # Raw pixel value in the zone mask png -> class name. The position of each
    # key in this dict is the ordinal class index used in the returned mask.
    px_class_values_zones: ClassVar[dict[int, str]] = {
        0: 'N/A',
        64: 'rock',
        127: 'glacier',
        254: 'ocean/ice melange',
    }

    # One color per zone class, in ordinal order.
    zone_class_colors = ('black', 'brown', 'lightgray', 'blue')

    zone_cmap = ListedColormap(zone_class_colors)

    # Raw pixel value in the front mask png -> class name.
    px_class_values_fronts: ClassVar[dict[int, str]] = {0: 'no front', 255: 'front'}

    def __init__(
        self,
        root: Path = 'data',
        split: str = 'train',
        transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
        download: bool = False,
        checksum: bool = False,
    ) -> None:
        """Initialize a new instance of CaFFe dataset.

        Args:
            root: root directory where dataset can be found
            split: one of "train", "val", or "test"
            transforms: a function/transform that takes input sample and its
                target as entry and returns a transformed version
            download: if True, download dataset and store it in the root directory
            checksum: if True, check the MD5 of the downloaded files (may be slow)

        Raises:
            AssertionError: if ``split`` argument is invalid
            DatasetNotFoundError: If dataset is not found and *download* is False.
        """
        # Raised explicitly (rather than via ``assert``) so that the check is
        # still performed when Python runs with ``-O``.
        if split not in self.valid_splits:
            raise AssertionError(f'split must be one of {self.valid_splits}')

        self.root = root
        self.split = split
        self.transforms = transforms
        self.download = download
        self.checksum = checksum

        self._verify()

        # The zone masks drive the file listing; image and front mask names
        # are derived from the zone mask name in ``__getitem__``. Sorting makes
        # the index -> sample mapping independent of filesystem listing order.
        self.fpaths = sorted(
            glob.glob(
                os.path.join(
                    self.root, self.data_dir, self.mask_dirs[1], self.split, '*.png'
                )
            )
        )

        self.ordinal_map_zones = self._build_ordinal_map(self.px_class_values_zones)
        self.ordinal_map_fronts = self._build_ordinal_map(self.px_class_values_fronts)

    @staticmethod
    def _build_ordinal_map(px_class_values: dict[int, str]) -> Tensor:
        """Build a lookup table from raw pixel values to ordinal class indices.

        Args:
            px_class_values: mapping of raw pixel value to class name; the
                iteration order of the keys defines the ordinal indices

        Returns:
            a 1D long tensor of length ``max(key) + 1`` where entry ``v`` is the
            ordinal of pixel value ``v``; values that are not keys map to 0
        """
        lookup = torch.zeros(max(px_class_values) + 1, dtype=torch.long)
        for ordinal, px_value in enumerate(px_class_values):
            lookup[px_value] = ordinal
        return lookup

    @staticmethod
    def _read_tensor(path: str) -> Tensor:
        """Load an image file into a tensor.

        Args:
            path: path of the image file to read

        Returns:
            the image data as a tensor
        """
        # The context manager releases the underlying file handle promptly.
        with Image.open(path) as img:
            return torch.from_numpy(np.array(img))

    def __len__(self) -> int:
        """Return the number of images in the dataset."""
        return len(self.fpaths)

    def __getitem__(self, idx: int) -> dict[str, Tensor]:
        """Return the image and mask at the given index.

        Args:
            idx: index of the image and mask to return

        Returns:
            dict: a dict with keys ``'image'`` (float tensor with a leading
            channel dimension added), ``'mask_front'`` and ``'mask_zones'``
            (long tensors holding ordinal class indices)
        """
        zones_filename = os.path.basename(self.fpaths[idx])
        # Image and front mask share the zone mask's name apart from the
        # '_zones_' infix.
        img_filename = zones_filename.replace('_zones_', '_')
        front_filename = zones_filename.replace('_zones_', '_front_')

        split_root = os.path.join(self.root, self.data_dir)
        img_path = os.path.join(split_root, self.image_dir, self.split, img_filename)
        front_path = os.path.join(
            split_root, self.mask_dirs[0], self.split, front_filename
        )
        zones_path = os.path.join(
            split_root, self.mask_dirs[1], self.split, zones_filename
        )

        img = self._read_tensor(img_path).unsqueeze(0).float()
        front_mask = self._read_tensor(front_path).long()
        zone_mask = self._read_tensor(zones_path).long()

        # Convert raw pixel values to ordinal class indices via the lookup tables.
        zone_mask = self.ordinal_map_zones[zone_mask]
        front_mask = self.ordinal_map_fronts[front_mask]

        sample = {'image': img, 'mask_front': front_mask, 'mask_zones': zone_mask}

        if self.transforms:
            sample = self.transforms(sample)

        return sample

    def _verify(self) -> None:
        """Verify the integrity of the dataset.

        Raises:
            DatasetNotFoundError: if neither the extracted directories nor the
                zip file are present and ``download`` is False
        """
        # The dataset counts as present when the image directory and every
        # mask directory exist for the requested split.
        data_root = os.path.join(self.root, self.data_dir)
        required_dirs = (self.image_dir, *self.mask_dirs)
        if all(
            os.path.exists(os.path.join(data_root, subdir, self.split))
            for subdir in required_dirs
        ):
            return

        # check download of zipfile
        if os.path.exists(os.path.join(self.root, self.zipfilename)):
            self._extract()
            return

        if not self.download:
            raise DatasetNotFoundError(self)

        self._download()

    def _download(self) -> None:
        """Download the dataset."""
        download_and_extract_archive(
            self.url,
            self.root,
            filename=self.zipfilename,
            md5=self.md5 if self.checksum else None,
        )

    def _extract(self) -> None:
        """Extract the dataset."""
        extract_archive(os.path.join(self.root, self.zipfilename), self.root)

    def plot(
        self,
        sample: dict[str, Tensor],
        show_titles: bool = True,
        suptitle: str | None = None,
    ) -> Figure:
        """Plot a sample from the dataset.

        Args:
            sample: a sample returned by :meth:`CaFFe.__getitem__`
            show_titles: flag indicating whether to show titles above each panel
            suptitle: optional string to use as a suptitle

        Returns:
            a matplotlib Figure with the rendered sample
        """
        has_prediction = 'prediction' in sample
        ncols = 4 if has_prediction else 3

        fig, axs = plt.subplots(1, ncols, figsize=(15, 5))

        axs[0].imshow(sample['image'].permute(1, 2, 0).numpy())
        axs[0].axis('off')

        axs[1].imshow(sample['mask_front'].numpy(), cmap='gray')
        axs[1].axis('off')

        zone_mask = sample['mask_zones'].numpy()
        unique_classes = np.unique(zone_mask)
        axs[2].imshow(zone_mask, cmap=self.zone_cmap)
        axs[2].axis('off')

        # Only list the zone classes that actually occur in this sample.
        handles = [
            mpatches.Patch(
                color=self.zone_cmap(ordinal),
                label='\n'.join(
                    textwrap.wrap(self.px_class_values_zones[px_class], width=10)
                ),
            )
            for ordinal, px_class in enumerate(self.px_class_values_zones)
            if ordinal in unique_classes
        ]
        axs[2].legend(handles=handles, loc='upper right', bbox_to_anchor=(1.4, 1))

        if show_titles:
            axs[0].set_title('Image')
            axs[1].set_title('Front Mask')
            axs[2].set_title('Zone Mask')

        if has_prediction:
            axs[3].imshow(sample['prediction'].numpy(), cmap='gray')
            axs[3].axis('off')
            if show_titles:
                axs[3].set_title('Prediction')

        if suptitle:
            fig.suptitle(suptitle)

        return fig

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources