Shortcuts

Source code for torchgeo.datasets.mmflood

# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.

"""MMFlood dataset."""

from __future__ import annotations

import os
import shutil
from collections.abc import Callable
from glob import glob
from typing import ClassVar, Literal

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.figure import Figure
from pyproj import CRS
from torch import Tensor

from .errors import DatasetNotFoundError
from .geo import IntersectionDataset, RasterDataset
from .utils import GeoSlice, Path, download_url, extract_archive


class MMFloodComponent(RasterDataset):
    """Raster dataset exposing a single MMFlood layer (one ``content`` folder)."""

    def __init__(
        self,
        subfolders: list[str],
        content: Literal['s1_raw', 'DEM', 'hydro', 'mask'],
        root: Path = 'data',
        crs: CRS | None = None,
        res: float | tuple[float, float] | None = None,
        transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
        cache: bool = False,
    ) -> None:
        """Collect the GeoTIFFs of one MMFlood layer and build the raster dataset.

        Args:
            subfolders: directory name prefixes whose tiles should be loaded
            content: name of the layer (sub-directory) to load
            root: root directory where dataset can be found
            crs: :term:`coordinate reference system (CRS)` to warp to
                (defaults to the CRS of the first file found)
            res: resolution of the dataset in units of CRS in (xres, yres) format. If a
                single float is provided, it is used for both the x and y resolution.
                (defaults to the resolution of the first file found)
            transforms: a function/transform that takes input sample and its target as
                entry and returns a transformed version
            cache: if True, cache file handle to speed up repeated sampling
        """
        self.content = content
        # Every layer except the ground-truth mask is flagged as imagery.
        self.is_image = content != 'mask'

        # One glob per requested prefix: <root>/<any dir>/<prefix>*-*/<content>/*.tif,
        # flattened and ordered lexicographically.
        tif_paths = sorted(
            match
            for prefix in subfolders
            for match in glob(
                os.path.join(root, '**', f'{prefix}*-*', self.content, '*.tif')
            )
        )

        super().__init__(tif_paths, crs, res, transforms=transforms, cache=cache)


class MMFlood(IntersectionDataset):
    """MMFlood dataset.

    `MMFlood <https://huggingface.co/datasets/links-ads/mmflood>`__ dataset is a
    multimodal flood delineation dataset. Sentinel-1 data is matched with masks and
    DEM data for all available tiles. If hydrography maps are loaded, only a subset
    of the dataset is loaded, since only 1,012 Sentinel-1 tiles have a corresponding
    hydrography map.

    Some Sentinel-1 tiles have missing data, which are automatically set to 0.
    Corresponding pixels in masks are set to 255 and should be ignored in
    performance computation.

    Dataset features:

    * 1,748 Sentinel-1 tiles of varying pixel dimensions
    * multimodal dataset
    * 95 flood events from 42 different countries
    * includes DEMs
    * includes hydrography maps (available for 1,012 tiles out of 1,748)
    * flood delineation maps (ground truth) are obtained from Copernicus EMS

    Dataset classes:

    0. no flood
    1. flood

    If you use this dataset in your research, please cite the following paper:

    * https://doi.org/10.1109/ACCESS.2022.3205419

    .. versionadded:: 0.7
    """

    # Base URL (pinned to a specific revision) that every remote file name is appended to.
    url = 'https://huggingface.co/datasets/links-ads/mmflood/resolve/24ca097306c9e50ad0711903c11e1ba13ea1bedc/'
    # Mask value written by __getitem__ for pixels whose image data is missing.
    _ignore_index = 255
    # Number of archive parts that are downloaded and concatenated.
    _nparts = 11
    metadata: ClassVar[dict[str, str]] = {
        'part_file': 'activations.tar.{part}.gz.part',
        'filename': 'activations.tar.gz',
        'directory': 'activations',
        'metadata_file': 'activations.json',
    }
    _splits: ClassVar[set[str]] = {'train', 'val', 'test'}
    _md5: ClassVar[dict[str, str]] = {
        'activations.json': 'de33a3ac7e55a0051ada21cbdfbb4745',
        'activations.tar.gz': '3cd4c4fe7506aa40263f74639d85ccce',
        'activations.tar.000.gz.part': 'a8424653edca6e79999831bdda53d4dc',
        'activations.tar.001.gz.part': '517def8760d3ce86885c7600c77a1d6c',
        'activations.tar.002.gz.part': '6797b97121f5b98ff58fde7491f584b2',
        'activations.tar.003.gz.part': 'e69d2a6b1746ef869d1da4d22018a71a',
        'activations.tar.004.gz.part': '0ccf7ea69ea6c0e88db1b1015ec3361e',
        'activations.tar.005.gz.part': '8ef6765afe20f254b1e752d7a2742fda',
        'activations.tar.006.gz.part': '3f330a44b66511b7a95f4a555f8b793a',
        'activations.tar.007.gz.part': '1d2046b5f3c473c3681a05dc94b29b86',
        'activations.tar.008.gz.part': 'f386b5acf78f8ae34592404c6c7ec43c',
        'activations.tar.009.gz.part': 'dd5317a3c0d33de815beadb9850baa38',
        'activations.tar.010.gz.part': '5a14a7e3f916c5dcf288c2ca88daf4d0',
    }

    def __init__(
        self,
        root: Path = 'data',
        crs: CRS | None = None,
        res: float | tuple[float, float] | None = None,
        split: str = 'train',
        include_dem: bool = False,
        include_hydro: bool = False,
        transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
        download: bool = False,
        checksum: bool = False,
        cache: bool = False,
    ) -> None:
        """Initialize a new MMFlood dataset instance.

        Args:
            root: root directory where dataset can be found
            crs: :term:`coordinate reference system (CRS)` to warp to
                (defaults to the CRS of the first file found)
            res: resolution of the dataset in units of CRS in (xres, yres) format. If a
                single float is provided, it is used for both the x and y resolution.
                (defaults to the resolution of the first file found)
            split: train/val/test split to load
            include_dem: If True, DEM data is concatenated after Sentinel-1 bands.
            include_hydro: If True, hydrography data is concatenated as last channel.
                Only a smaller subset of the original dataset is loaded in this case.
            transforms: a function/transform that takes input sample and its target as
                entry and returns a transformed version
            download: if True, download dataset and store it in the root directory
            checksum: if True, check the MD5 of the downloaded files (may be slow)
            cache: if True, cache file handle to speed up repeated sampling

        Raises:
            DatasetNotFoundError: If dataset is not found and *download* is False.
            AssertionError: If *split* is invalid.
        """
        # AssertionError is the documented contract for an invalid split.
        assert split in self._splits

        self.root = root
        self.split = split
        self.include_dem = include_dem
        self.include_hydro = include_hydro
        self.transforms = transforms
        self.download = download
        self.checksum = checksum

        # Verify integrity of the dataset (downloads/extracts it when requested)
        self._verify()

        # The metadata JSON is transposed so that each row describes one entry
        # and the row labels can be used as folder name prefixes.
        self.metadata_df = pd.read_json(
            os.path.join(self.root, self.metadata['metadata_file'])
        ).transpose()
        split_subfolders = self.metadata_df[
            self.metadata_df['subset'] == self.split
        ].index.tolist()

        self.image: MMFloodComponent | IntersectionDataset = MMFloodComponent(
            split_subfolders, 's1_raw', root, crs, res, cache=cache
        )
        # NOTE(review): after each intersection the index is replaced by the index
        # of the component that was just added -- presumably so that only tiles
        # covered by that component remain; confirm against IntersectionDataset.
        if include_dem:
            dem = MMFloodComponent(split_subfolders, 'DEM', root, crs, res, cache=cache)
            self.image = self.image & dem
            self.image.index = dem.index
        if include_hydro:
            hydro = MMFloodComponent(
                split_subfolders, 'hydro', root, crs, res, cache=cache
            )
            self.image = self.image & hydro
            self.image.index = hydro.index
        self.mask = MMFloodComponent(
            split_subfolders, 'mask', root, crs, res, cache=cache
        )

        super().__init__(self.image, self.mask, transforms=transforms)
        self.index = self.image.index

    def _part_filename(self, idx: int) -> str:
        """Return the file name of archive part *idx* (zero-padded to 3 digits)."""
        return self.metadata['part_file'].format(part=f'{idx:03}')

    def _merge_tar_files(self) -> None:
        """Merge part tar gz files.

        The parts are streamed in order into a temporary file that is renamed to
        the final archive name only after every part has been appended, so an
        interrupted merge never leaves a truncated archive under the final name.
        """
        dst_path = os.path.join(self.root, self.metadata['filename'])
        tmp_path = f'{dst_path}.tmp'

        print('Merging separate part files...')
        with open(tmp_path, 'wb') as dst_fp:
            for idx in range(self._nparts):
                part_path = os.path.join(self.root, self._part_filename(idx))
                print(f'Processing file {part_path!s}')
                with open(part_path, 'rb') as part_fp:
                    # Copy in chunks instead of loading a whole part into memory.
                    shutil.copyfileobj(part_fp, dst_fp)
        os.replace(tmp_path, dst_path)

    def __getitem__(self, query: GeoSlice) -> dict[str, Tensor]:
        """Retrieve input, target, and/or metadata indexed by spatiotemporal slice.

        Args:
            query: [xmin:xmax:xres, ymin:ymax:yres, tmin:tmax:tres] coordinates to index.

        Returns:
            Sample of input, target, and/or metadata at that index.

        Raises:
            IndexError: If *query* is not found in the index.
        """
        data = super().__getitem__(query)
        # A pixel is invalid as soon as any channel is NaN at that location.
        missing_data = data['image'].isnan().any(dim=0)
        # Set all pixel values of invalid areas to 0, all mask values to 255
        data['image'][:, missing_data] = 0
        data['mask'][missing_data] = self._ignore_index
        return data

    def _check_and_download(self, filename: str, url: str) -> None:
        """Download *url* into the root directory unless *filename* already exists.

        The MD5 registered for *filename* is passed to the downloader only when
        checksum verification was requested.
        """
        if os.path.exists(os.path.join(self.root, filename)):
            return
        md5 = self._md5[filename] if self.checksum else None
        download_url(url, self.root, filename, md5)

    def _download(self) -> None:
        """Download the dataset.

        Archive parts are fetched only when the merged archive is not already
        present; the metadata file is fetched whenever it is missing.
        """
        archive_path = os.path.join(self.root, self.metadata['filename'])
        if not os.path.exists(archive_path):
            for idx in range(self._nparts):
                part_file = self._part_filename(idx)
                self._check_and_download(part_file, self.url + part_file)

        metadata_file = self.metadata['metadata_file']
        self._check_and_download(metadata_file, self.url + metadata_file)

    def _extract(self) -> None:
        """Extract the dataset."""
        filepath = os.path.join(self.root, self.metadata['filename'])
        if str(filepath).endswith('.tar.gz'):
            extract_archive(filepath)

    def _verify(self) -> None:
        """Verify the integrity of the dataset.

        Raises:
            DatasetNotFoundError: If the extracted directory or the metadata file
                is missing and *download* is False.
        """
        dirpath = os.path.join(self.root, self.metadata['directory'])
        metadata_filepath = os.path.join(self.root, self.metadata['metadata_file'])

        # Check if both metadata file and directory exist
        if os.path.isdir(dirpath) and os.path.isfile(metadata_filepath):
            return

        if not self.download:
            raise DatasetNotFoundError(self)

        self._download()
        # _download skips the parts when the merged archive already exists;
        # merging in that case would truncate the archive and then fail on the
        # missing parts, so only merge when there is no archive yet.
        archive_path = os.path.join(self.root, self.metadata['filename'])
        if not os.path.exists(archive_path):
            self._merge_tar_files()
        self._extract()

    def plot(
        self,
        sample: dict[str, Tensor],
        show_titles: bool = True,
        suptitle: str | None = None,
    ) -> Figure:
        """Plot a sample from the dataset.

        Panels are drawn left to right in this order, each one only when
        available: image (false color), DEM, hydrography map, mask, prediction.

        Args:
            sample: a sample returned by :meth:`__getitem__`
            show_titles: flag indicating whether to show titles above each panel
            suptitle: optional suptitle to use for figure

        Returns:
            a matplotlib Figure with the rendered sample
        """
        # The first two channels are used as co- and cross-polarization bands.
        image = sample['image'][[0, 1]].permute(1, 2, 0).numpy()

        # Compute False Color image, from Sentinel1 plot function
        co_polarization = image[..., 0]  # transmit == receive
        cross_polarization = image[..., 1]  # transmit != receive
        ratio = co_polarization / cross_polarization

        # https://gis.stackexchange.com/a/400780/123758
        co_polarization = np.clip(co_polarization / 0.3, a_min=0, a_max=1)
        cross_polarization = np.clip(cross_polarization / 0.05, a_min=0, a_max=1)
        ratio = np.clip(ratio / 25, a_min=0, a_max=1)
        false_color = np.stack((co_polarization, cross_polarization, ratio), axis=-1)

        # Each panel is (data, title, colormap); cmap=None is imshow's default.
        panels: list[tuple[np.ndarray, str, str | None]] = [
            (false_color, 'Image', None)
        ]
        if self.include_dem:
            # DEM sits just before the hydrography channel when both are loaded.
            dem_idx = -2 if self.include_hydro else -1
            dem = sample['image'][dem_idx].squeeze(0).numpy()
            panels.append((dem, 'DEM', 'gray'))
        if self.include_hydro:
            hydro = sample['image'][-1].squeeze(0).numpy()
            panels.append((hydro, 'Hydrography Map', 'gray'))
        if 'mask' in sample:
            # Tensor.numpy() shares memory with the tensor: copy before editing
            # so that plotting never alters the caller's sample.
            mask = sample['mask'].numpy().copy()
            # Set ignore_index values to 0
            mask[mask == self._ignore_index] = 0
            panels.append((mask, 'Mask', 'gray'))
        if 'prediction' in sample:
            panels.append((sample['prediction'].numpy(), 'Prediction', 'gray'))

        # Generate the figure; squeeze=False keeps a 2-D axes array even when
        # there is a single panel, so indexing below is always valid.
        ncols = len(panels)
        fig, axs = plt.subplots(ncols=ncols, figsize=(4 * ncols, 4), squeeze=False)
        for ax, (data, title, cmap) in zip(axs[0], panels):
            ax.imshow(data, cmap=cmap)
            ax.axis('off')
            if show_titles:
                ax.set_title(title)

        if suptitle is not None:
            fig.suptitle(suptitle)

        return fig

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources