Shortcuts

Source code for torchgeo.datasets.geonrw

# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.

"""GeoNRW dataset."""

import os
from collections.abc import Callable
from glob import glob
from typing import ClassVar

import matplotlib
import matplotlib.cm
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from PIL import Image
from torch import Tensor
from torchvision import transforms

from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
from .utils import Path, download_and_extract_archive, extract_archive


[docs]class GeoNRW(NonGeoDataset): """GeoNRW dataset. This datasets contains RGB, DEM and segmentation label data from North Rhine-Westphalia, Germany. Dataset features: * 7298 training and 485 test samples * RGB images, 1000x1000px normalized to [0, 1] * DEM images, unnormalized * segmentation labels Dataset format: * RGB images are three-channel jp2 * DEM images are single-channel tif * segmentation labels are single-channel tif Dataset classes: 0. background 1. forest 2. water 3. agricultural 4. residential,commercial,industrial 5. grassland,swamp,shrubbery 6. railway,trainstation 7. highway,squares 8. airport,shipyard 9. roads 10. buildings Additional information about the dataset can be found `on this site <https://ieee-dataport.org/open-access/geonrw>`__. If you use this dataset in your research, please cite the following paper: * https://ieeexplore.ieee.org/document/9406194 .. versionadded:: 0.6 """ # Splits taken from https://github.com/gbaier/geonrw/blob/ecfcdbca8cfaaeb490a9c6916980f385b9f3941a/pytorch/nrw.py#L48 splits = ('train', 'test') train_list: tuple[str, ...] = ( 'aachen', 'bergisch', 'bielefeld', 'bochum', 'bonn', 'borken', 'bottrop', 'coesfeld', 'dortmund', 'dueren', 'duisburg', 'ennepetal', 'erftstadt', 'essen', 'euskirchen', 'gelsenkirchen', 'guetersloh', 'hagen', 'hamm', 'heinsberg', 'herford', 'hoexter', 'kleve', 'koeln', 'krefeld', 'leverkusen', 'lippetal', 'lippstadt', 'lotte', 'moenchengladbach', 'moers', 'muelheim', 'muenster', 'oberhausen', 'paderborn', 'recklinghausen', 'remscheid', 'siegen', 'solingen', 'wuppertal', ) test_list: tuple[str, ...] = ('duesseldorf', 'herne', 'neuss') classes = ( 'background', 'forest', 'water', 'agricultural', 'residential,commercial,industrial', 'grassland,swamp,shrubbery', 'railway,trainstation', 'highway,squares', 'airport,shipyard', 'roads', 'buildings', ) colormap = mcolors.ListedColormap( [ '#000000', # matplotlib black for background '#2ca02c', # matplotlib green for forest '#1f77b4', # matplotlib blue for water '#8c564b', # matplotlib brown for agricultural '#7f7f7f', # matplotlib gray residential_commercial_industrial '#bcbd22', # matplotlib olive for grassland_swamp_shrubbery '#ff7f0e', # matplotlib orange for railway_trainstation '#9467bd', # matplotlib purple for highway_squares '#17becf', # matplotlib cyan for airport_shipyard '#d62728', # matplotlib red for roads '#e377c2', # matplotlib pink for buildings ] ) readers: ClassVar[dict[str, Callable[[str], Image.Image]]] = { 'rgb': lambda path: Image.open(path).convert('RGB'), 'dem': lambda path: Image.open(path).copy(), 'seg': lambda path: Image.open(path).convert('I;16'), } modality_filenames: ClassVar[dict[str, Callable[[list[str]], str]]] = { 'rgb': lambda utm_coords: '{}_{}_rgb.jp2'.format(*utm_coords), 'dem': lambda utm_coords: '{}_{}_dem.tif'.format(*utm_coords), 'seg': lambda utm_coords: '{}_{}_seg.tif'.format(*utm_coords), } modalities: tuple[str, ...] = ('rgb', 'dem', 'seg') url = 'https://hf.co/datasets/torchgeo/geonrw/resolve/3cb6bdf2a615b9e526c7dcff85fd1f20728081b7/{}' filename = 'nrw_dataset.tar.gz' md5 = 'd56ab50098d5452c33d08ff4e99ce281'
[docs] def __init__( self, root: Path = 'data', split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, ) -> None: """Initialize the GeoNRW dataset. Args: root: root directory where dataset can be found split: one of "train", or "test" transforms: a function/transform that takes input sample and its target as entry and returns a transformed version download: if True, download dataset and store it in the root directory checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: AssertionError: if ``split`` argument is invalid DatasetNotFoundError: If dataset is not found and *download* is False. """ assert split in self.splits, f'split must be one of {self.splits}' self.root = root self.split = split self.transforms = transforms self.download = download self.checksum = checksum self.city_names = self.test_list if split == 'test' else self.train_list self._verify() self.file_list = self._get_file_list()
def _get_file_list(self) -> list[str]: """Get a list of files for cities in the dataset split. Returns: list of filenames in the dataset split """ file_list: list[str] = [] for cn in self.city_names: pattern = os.path.join(self.root, cn, '*rgb.jp2') file_list.extend(glob(pattern)) return sorted(file_list)
[docs] def __len__(self) -> int: """Return the number of data points in the dataset. Returns: length of the dataset """ return len(self.file_list)
[docs] def __getitem__(self, index: int) -> dict[str, Tensor]: """Return an index within the dataset. Args: index: index to return Returns: data and label at that index """ to_tensor = transforms.ToTensor() path: str = self.file_list[index] utm_coords = os.path.basename(path).split('_')[:2] base_dir = os.path.dirname(path) sample: dict[str, Tensor] = {} for modality in self.modalities: modality_path = os.path.join( base_dir, self.modality_filenames[modality](utm_coords) ) sample[modality] = to_tensor(self.readers[modality](modality_path)) # rename to torchgeo standard keys sample['image'] = sample.pop('rgb').float() sample['mask'] = sample.pop('seg').long().squeeze(0) if self.transforms: sample = self.transforms(sample) return sample
def _verify(self) -> None: """Verify the integrity of the dataset.""" # check if city names directories exist all_exist = all( os.path.exists(os.path.join(self.root, cn)) for cn in self.city_names ) if all_exist: return # Check if the tar file has been downloaded if os.path.exists(os.path.join(self.root, self.filename)): extract_archive(os.path.join(self.root, self.filename), self.root) return # Check if the user requested to download the dataset if not self.download: raise DatasetNotFoundError(self) # Download the dataset self._download() def _download(self) -> None: """Download the dataset.""" download_and_extract_archive( self.url.format(self.filename), download_root=self.root, md5=self.md5 if self.checksum else None, )
[docs] def plot( self, sample: dict[str, Tensor], show_titles: bool = True, suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. Args: sample: a sample returned by :meth:`__getitem__` show_titles: flag indicating whether to show titles above each panel suptitle: optional suptitle to use for figure Returns: a matplotlib Figure with the rendered sample """ showing_predictions = 'prediction' in sample ncols = 3 if showing_predictions: prediction = sample['prediction'].long() ncols += 1 fig, axs = plt.subplots( nrows=1, ncols=ncols, figsize=(ncols * 5, 10), sharex=True ) axs[0].imshow(sample['image'].permute(1, 2, 0)) axs[0].axis('off') axs[1].imshow(sample['dem'].squeeze(0), cmap='gray') axs[1].axis('off') axs[2].imshow( sample['mask'].squeeze(0), self.colormap, vmin=0, vmax=10, interpolation='none', ) axs[2].axis('off') if showing_predictions: axs[3].imshow( prediction.squeeze(0), self.colormap, vmin=0, vmax=10, interpolation='none', ) # show classes in legend if show_titles: patches = [matplotlib.patches.Patch(color=c) for c in self.colormap.colors] # type: ignore axs[2].legend( patches, self.classes, loc='center left', bbox_to_anchor=(1, 0.5) ) if show_titles: axs[0].set_title('RGB Image') axs[1].set_title('DEM') axs[2].set_title('Labels') if suptitle is not None: fig.suptitle(suptitle, y=0.8) fig.tight_layout() return fig

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources