Shortcuts

Source code for torchgeo.datasets.copernicus.base

# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.

"""Copernicus-Bench abstract base class."""

import os
import re
from abc import ABC, abstractmethod
from collections.abc import Callable, Sequence
from typing import Any, Literal

import matplotlib.colors
import numpy as np
import pandas as pd
import rasterio as rio
import torch
from einops import rearrange
from matplotlib import pyplot as plt
from matplotlib.figure import Figure
from pyproj import Transformer
from torch import Tensor

from torchgeo.datasets.geo import NonGeoDataset

from ..errors import DatasetNotFoundError, RGBBandsMissingError
from ..utils import (
    Path,
    array_to_tensor,
    disambiguate_timestamp,
    download_and_extract_archive,
    extract_archive,
    percentile_normalization,
)


class CopernicusBenchBase(NonGeoDataset, ABC):
    """Abstract base class for all Copernicus-Bench datasets.

    Subclasses provide the download location, archive layout, band names, and
    (for classification / segmentation datasets) class names. This base class
    handles download/extraction, split-file parsing, image/mask loading, and
    plotting.

    If you use this dataset in your research, please cite the following paper:

    * https://arxiv.org/abs/2503.11849

    .. versionadded:: 0.7
    """

    @property
    @abstractmethod
    def url(self) -> str:
        """Download URL."""

    #: MD5 checksum.
    md5: str

    #: Zip file name.
    zipfile: str

    #: Subdirectory containing split files.
    directory: str

    #: Filename format of split files (formatted with the split name).
    filename = '{}.csv'

    #: Mask dtype to cast to, either torch.long for classification
    #: or torch.float for regression.
    dtype: torch.dtype = torch.long

    #: Regular expression used to extract date from filename. A named group
    #: ``date``, or the pair ``start``/``stop``, enables the ``time`` key
    #: in samples returned by :meth:`_load_image`.
    filename_regex = '.*'

    #: Date format string used to parse date from filename.
    date_format = '%Y%m%dT%H%M%S'

    @property
    @abstractmethod
    def all_bands(self) -> tuple[str, ...]:
        """All spectral channels."""

    @property
    @abstractmethod
    def rgb_bands(self) -> tuple[str, ...]:
        """Spectral channels used to make RGB plots."""

    #: Matplotlib color map for semantic segmentation and change detection plots.
    cmap: str | matplotlib.colors.Colormap

    #: List of classes for classification, semantic segmentation, and change detection.
    classes: tuple[str, ...]

    def __init__(
        self,
        root: Path = 'data',
        split: Literal['train', 'val', 'test'] = 'train',
        bands: Sequence[str] | None = None,
        transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
        download: bool = False,
        checksum: bool = False,
    ) -> None:
        """Initialize a new CopernicusBenchBase instance.

        Args:
            root: Root directory where dataset can be found.
            split: One of 'train', 'val', or 'test'.
            bands: Sequence of band names to load (defaults to all bands;
                an empty sequence also falls back to all bands).
            transforms: A function/transform that takes input sample and its
                target as entry and returns a transformed version.
            download: If True, download dataset and store it in the root directory.
            checksum: If True, check the MD5 of the downloaded files (may be slow).

        Raises:
            DatasetNotFoundError: If dataset is not found and *download* is False.
            ValueError: If a requested band is not in :attr:`all_bands`.
        """
        self.root = root
        self.split = split
        self.bands = bands or self.all_bands
        # rasterio band indexes are 1-based, hence the +1.
        self.band_indices = [self.all_bands.index(i) + 1 for i in self.bands]
        self.transforms = transforms
        self.download = download
        self.checksum = checksum

        self._verify()

        # The split file is a header-less CSV whose first column lists the
        # sample files belonging to this split.
        filepath = os.path.join(root, self.directory, self.filename.format(split))
        self.files = pd.read_csv(filepath, header=None)[0]

    def __len__(self) -> int:
        """Return the length of the dataset.

        Returns:
            Length of the dataset.
        """
        return len(self.files)

    def _load_image(self, path: str) -> dict[str, Tensor]:
        """Load an image and metadata.

        Args:
            path: File path to load.

        Returns:
            An image sample with an ``image`` key, plus ``lat``/``lon`` when the
            file is georeferenced and ``time`` when :attr:`filename_regex`
            captures a date.
        """
        sample: dict[str, Tensor] = {}
        with rio.open(path) as f:
            # Image
            image = f.read(self.band_indices).astype(np.float32)
            sample['image'] = torch.tensor(image)

            # Location: an identity transform means the file carries no
            # georeferencing, so no coordinates are emitted.
            if f.transform != rio.Affine.identity():
                x = (f.bounds.left + f.bounds.right) / 2
                y = (f.bounds.bottom + f.bounds.top) / 2
                transformer = Transformer.from_crs(f.crs, 'epsg:4326', always_xy=True)
                lon, lat = transformer.transform(x, y)
                sample['lat'] = torch.tensor(lat)
                sample['lon'] = torch.tensor(lon)

        # Time: midpoint (in POSIX seconds) of the period encoded in the filename.
        if match := re.match(self.filename_regex, os.path.basename(path)):
            if 'date' in match.groupdict():
                date_str = match.group('date')
                mint, maxt = disambiguate_timestamp(date_str, self.date_format)
                time = (mint.timestamp() + maxt.timestamp()) / 2
                sample['time'] = torch.tensor(time)
            elif 'start' in match.groupdict() and 'stop' in match.groupdict():
                start = match.group('start')
                stop = match.group('stop')
                mint, _ = disambiguate_timestamp(start, self.date_format)
                _, maxt = disambiguate_timestamp(stop, self.date_format)
                time = (mint.timestamp() + maxt.timestamp()) / 2
                sample['time'] = torch.tensor(time)

        return sample

    def _load_mask(self, path: str) -> dict[str, Tensor]:
        """Load a target mask.

        Args:
            path: File path to load.

        Returns:
            A target sample with a ``mask`` key holding band 1 of the file,
            cast to :attr:`dtype`.
        """
        sample: dict[str, Tensor] = {}
        with rio.open(path) as f:
            sample['mask'] = array_to_tensor(f.read(1)).to(self.dtype)

        return sample

    def _verify(self) -> None:
        """Verify the integrity of the dataset.

        Raises:
            DatasetNotFoundError: If neither the split file nor the archive is
                present and *download* is False.
        """
        # Check if the files already exist
        path = os.path.join(self.root, self.directory, self.filename.format(self.split))
        if os.path.exists(path):
            return

        # Check if the zip file already exists (if so then extract)
        if os.path.exists(os.path.join(self.root, self.zipfile)):
            self._extract()
            return

        # Check if the user requested to download the dataset
        if not self.download:
            raise DatasetNotFoundError(self)

        # Download and extract the dataset
        self._download()

    def _extract(self) -> None:
        """Extract the dataset."""
        extract_archive(os.path.join(self.root, self.zipfile))

    def _download(self) -> None:
        """Download the dataset."""
        # Only verify the MD5 when the user asked for it (it may be slow).
        md5 = self.md5 if self.checksum else None
        download_and_extract_archive(self.url, self.root, md5=md5)

    def plot(
        self,
        sample: dict[str, Tensor],
        show_titles: bool = True,
        suptitle: str | None = None,
    ) -> Figure:
        """Plot a sample from the dataset.

        One panel is drawn per time step of the image. A mask panel follows if
        the sample has a ``mask``. A prediction panel follows that only when the
        prediction is a map accompanying a mask. Classification labels and
        predictions are shown in the image titles instead.

        Args:
            sample: A sample returned by :meth:`NonGeoDataset.__getitem__`.
            show_titles: Flag indicating whether to show titles above each panel.
            suptitle: Optional string to use as a suptitle.

        Returns:
            A matplotlib Figure with the rendered sample.

        Raises:
            RGBBandsMissingError: If *bands* does not include all RGB bands.
        """
        rgb_indices = []
        for band in self.rgb_bands:
            if band in self.bands:
                rgb_indices.append(self.bands.index(band))
            else:
                raise RGBBandsMissingError()

        # Static -> time series
        images = sample['image'].numpy()
        if sample['image'].dim() == 3:
            images = np.expand_dims(images, axis=0)

        # A prediction gets its own panel only when it is a map shown next to a
        # mask; classification predictions are rendered in the title, so they
        # must not reserve an (otherwise empty) subplot column.
        has_mask = 'mask' in sample
        has_prediction_map = has_mask and 'prediction' in sample
        ncols = len(images) + int(has_mask) + int(has_prediction_map)

        fig, ax = plt.subplots(ncols=ncols, squeeze=False)

        # Label
        title = 'Image'
        if 'label' in sample:
            if sample['label'].dim() == 0:
                # Multiclass classification
                label: Any = self.classes[sample['label']]
                if 'prediction' in sample:
                    prediction: Any = self.classes[sample['prediction']]
            else:
                # Multilabel classification
                label = sample['label'].numpy().nonzero()[0]
                if 'prediction' in sample:
                    prediction = sample['prediction'].numpy().nonzero()[0]

            title = f'Label: {label}'
            if 'prediction' in sample:
                title += f'\nPrediction: {prediction}'

        # Image
        images = images[:, rgb_indices]
        if set(self.rgb_bands) <= {'VV', 'VH', 'HH', 'HV'}:
            # SAR: only two polarizations, so synthesize a third channel as
            # their mean to obtain a 3-channel composite.
            vv = images[:, 0]
            vh = images[:, 1]
            images = np.stack([vv, vh, (vv + vh) / 2], axis=1)

        # Normalize exactly once; a second percentile stretch would just
        # re-stretch values that were already clipped by the first.
        images = percentile_normalization(images)
        images = rearrange(images, 't c h w -> t h w c')

        for i, image in enumerate(images):
            ax[0, i].imshow(image)
            ax[0, i].axis('off')
            if show_titles:
                ax[0, i].set_title(title)

        # Mask (and dense prediction) panels come right after the image panels.
        col = len(images)
        if has_mask:
            kwargs: dict[str, Any] = {'cmap': self.cmap}
            # `classes` is only annotated here, so it exists only on subclasses
            # that define it (i.e. not on regression datasets).
            if hasattr(self, 'classes'):
                # Semantic segmentation: fix the colour range to the class ids.
                kwargs |= {
                    'vmin': 0,
                    'vmax': len(self.classes) - 1,
                    'interpolation': 'none',
                }

            mask = sample['mask']
            ax[0, col].imshow(mask, **kwargs)
            ax[0, col].axis('off')
            if show_titles:
                ax[0, col].set_title('Mask')

            if has_prediction_map:
                prediction = sample['prediction']
                ax[0, col + 1].imshow(prediction, **kwargs)
                ax[0, col + 1].axis('off')
                if show_titles:
                    ax[0, col + 1].set_title('Prediction')

        if suptitle is not None:
            fig.suptitle(suptitle)

        fig.tight_layout()

        return fig

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources