Shortcuts

Source code for torchgeo.datasets.mmearth

# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.

"""MMEarth Dataset."""

import json
import os
from collections.abc import Callable, Sequence
from datetime import datetime, timedelta
from typing import Any, ClassVar, cast

import matplotlib.pyplot as plt
import numpy as np
import torch
from einops import rearrange
from matplotlib.figure import Figure
from torch import Tensor

from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
from .utils import Path, lazy_import, percentile_normalization


class MMEarth(NonGeoDataset):
    """MMEarth dataset.

    There are three different versions of the dataset, that vary in image size
    and the number of tiles:

    * MMEarth: 128x128 px, 1.2M tiles, 579 GB
    * MMEarth64: 64x64 px, 1.2M tiles, 162 GB
    * MMEarth100k: 128x128 px, 100K tiles, 48 GB

    The dataset consists of 12 modalities:

    * Aster: elevation and slope
    * Biome: 14 terrestrial ecosystem categories
    * ETH Canopy Height: Canopy height and standard deviation
    * Dynamic World: 9 landcover categories
    * Ecoregion: 846 ecoregion categories
    * ERA5: Climate reanalysis data for temperature mean, min, and max of
      [year, month, previous month] and precipitation total of
      [year, month, previous month] (counted as separate modalities)
    * ESA World Cover: 11 landcover categories
    * Sentinel-1: VV, VH, HV, HH for ascending/descending orbit
    * Sentinel-2: multi-spectral B1-B12 for L1C/L2A products
    * Geolocation: cyclic encoding of latitude and longitude
    * Date: cyclic encoding of month

    Additionally, there are three masks available as modalities:

    * Sentinel-2 Cloudmask: Sentinel-2 cloud mask
    * Sentinel-2 Cloud probability: Sentinel-2 cloud probability
    * Sentinel-2 SCL: Sentinel-2 scene classification

    that are synchronized across tiles.

    Dataset format:

    * Dataset in single HDF5 file
    * JSON files for band statistics, splits, and tile information

    For additional information, as well as bash scripts to download the data,
    please refer to the `official repository
    <https://github.com/vishalned/MMEarth-data?tab=readme-ov-file#data-download>`_.

    If you use this dataset in your research, please cite the following paper:

    * https://arxiv.org/abs/2405.02771

    .. note::
       This dataset requires the following additional library to be installed:

       * `h5py <https://pypi.org/project/h5py/>`_ to load the dataset

    .. versionadded:: 0.7
    """

    # supported dataset versions (see size/tile-count table in the class docstring)
    subsets = ('MMEarth', 'MMEarth64', 'MMEarth100k')

    # on-disk directory name for each dataset version; also used as the
    # basename prefix of the HDF5 and JSON files
    filenames: ClassVar[dict[str, str]] = {
        'MMEarth': 'data_1M_v001',
        'MMEarth64': 'data_1M_v001_64',
        'MMEarth100k': 'data_100k_v001',
    }

    # every modality that can be requested via the ``modalities`` argument
    all_modalities = (
        'aster',
        'biome',
        'canopy_height_eth',
        'dynamic_world',
        'eco_region',
        'era5',
        'esa_worldcover',
        'sentinel1_asc',
        'sentinel1_desc',
        'sentinel2',
        'sentinel2_cloudmask',
        'sentinel2_cloudprod',
        'sentinel2_scl',
    )

    # canonical band order per modality; list position defines the channel
    # index used for band selection and normalization statistics
    # See https://github.com/vishalned/MMEarth-train/blob/8d6114e8e3ccb5ca5d98858e742dac24350b64fd/MODALITIES.py#L108C1-L160C2
    all_modality_bands: ClassVar[dict[str, list[str]]] = {
        'sentinel2': [
            'B1',
            'B2',
            'B3',
            'B4',
            'B5',
            'B6',
            'B7',
            'B8A',
            'B8',
            'B9',
            'B10',
            'B11',
            'B12',
        ],
        'sentinel2_cloudmask': ['QA60'],
        'sentinel2_cloudprod': ['MSK_CLDPRB'],
        'sentinel2_scl': ['SCL'],
        'sentinel1_asc': ['VV', 'VH', 'HH', 'HV'],
        'sentinel1_desc': ['VV', 'VH', 'HH', 'HV'],
        'aster': ['b1', 'slope'],  # elevation and slope
        'era5': [
            'prev_temperature_2m',  # previous month avg temp
            'prev_temperature_2m_min',  # previous month min temp
            'prev_temperature_2m_max',  # previous month max temp
            'prev_total_precipitation_sum',  # previous month total precip
            'curr_temperature_2m',  # current month avg temp
            'curr_temperature_2m_min',  # current month min temp
            'curr_temperature_2m_max',  # current month max temp
            'curr_total_precipitation_sum',  # current month total precip
            '0_temperature_2m_mean',  # year avg temp
            '1_temperature_2m_min_min',  # year min temp
            '2_temperature_2m_max_max',  # year max temp
            '3_total_precipitation_sum_sum',  # year total precip
        ],
        'dynamic_world': ['label'],
        'canopy_height_eth': ['height', 'std'],
        'lat': ['sin', 'cos'],
        'lon': ['sin', 'cos'],
        'biome': ['biome'],
        'eco_region': ['eco_region'],
        'month': ['sin_month', 'cos_month'],
        'esa_worldcover': ['Map'],
    }

    # per-modality sentinel value marking missing data in the raw arrays
    # See https://github.com/vishalned/MMEarth-train/blob/8d6114e8e3ccb5ca5d98858e742dac24350b64fd/MODALITIES.py#L36
    no_data_vals: ClassVar[dict[str, int | float]] = {
        'sentinel2': 0,
        'sentinel2_cloudmask': 65535,
        'sentinel2_cloudprod': 65535,
        'sentinel2_scl': 255,
        'sentinel1_asc': float('-inf'),
        'sentinel1_desc': float('-inf'),
        'aster': float('-inf'),
        'canopy_height_eth': 255,
        'dynamic_world': 0,
        'esa_worldcover': 255,
        'lat': float('-inf'),
        'lon': float('-inf'),
        'month': float('-inf'),
        'era5': float('inf'),
        'biome': 255,
        'eco_region': 65535,
    }

    # supported normalization schemes for continuous modalities
    norm_modes = ('z-score', 'min-max')

    # prefix prepended to the modality name to build the sample-dict key
    # ('image_' / 'mask_' / '' for scalar modalities such as era5)
    modality_category_name: ClassVar[dict[str, str]] = {
        'sentinel1_asc': 'image_',
        'sentinel1_desc': 'image_',
        'sentinel2': 'image_',
        'sentinel2_cloudmask': 'mask_',
        'sentinel2_cloudprod': 'mask_',
        'sentinel2_scl': 'mask_',
        'aster': 'image_',
        'era5': '',
        'canopy_height_eth': 'image_',
        'dynamic_world': 'mask_',
        'esa_worldcover': 'mask_',
    }
[docs] def __init__( self, root: Path = 'data', subset: str = 'MMEarth', modalities: Sequence[str] = all_modalities, modality_bands: dict[str, list[str]] | None = None, normalization_mode: str = 'z-score', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, ) -> None: """Initialize the MMEarth dataset. Args: root: root directory where dataset can be found subset: one of "MMEarth", "MMEarth64", or "MMEarth100k" modalities: list of modalities to load modality_bands: dictionary of modality bands, see normalization_mode: one of "z-score" or "min-max" transforms: a function/transform that takes input sample dictionary and returns a transformed version Raises: AssertionError: if *normalization_mode* or *subset* DatasetNotFoundError: If dataset is not found and *download* is False. """ lazy_import('h5py') assert normalization_mode in self.norm_modes, ( f'Invalid normalization mode: {normalization_mode}, please choose from {self.norm_modes}' ) assert subset in self.subsets, ( f'Invalid dataset version: {subset}, please choose from {self.subsets}' ) self._validate_modalities(modalities) self.modalities = modalities if modality_bands is None: modality_bands = { modality: self.all_modality_bands[modality] for modality in modalities } self._validate_modality_bands(modality_bands) self.modality_bands = modality_bands self.root = root self.subset = subset self.normalization_mode = normalization_mode self.split = 'train' self.transforms = transforms self.dataset_filename = f'{self.filenames[subset]}.h5' self.band_stats_filename = f'{self.filenames[subset]}_band_stats.json' self.splits_filename = f'{self.filenames[subset]}_splits.json' self.tile_info_filename = f'{self.filenames[subset]}_tile_info.json' self._verify() self.indices = self._load_indices() self.band_stats = self._load_normalization_stats() self.tile_info = self._load_tile_info()
    def _verify(self) -> None:
        """Verify the dataset.

        Checks that the HDF5 file and all three auxiliary JSON files exist.

        Raises:
            DatasetNotFoundError: if any required file is missing
        """
        data_dir = os.path.join(self.root, self.filenames[self.subset])
        exists = [
            os.path.exists(os.path.join(data_dir, f))
            for f in [
                self.dataset_filename,
                self.band_stats_filename,
                self.splits_filename,
                self.tile_info_filename,
            ]
        ]
        if not all(exists):
            raise DatasetNotFoundError(self)

    def _load_indices(self) -> list[int]:
        """Load the indices for the dataset split.

        Returns:
            list of indices
        """
        with open(
            os.path.join(self.root, self.filenames[self.subset], self.splits_filename)
        ) as f:
            # JSON maps split name -> list of absolute HDF5 row indices
            split_indices: dict[str, list[int]] = json.load(f)

        return split_indices[self.split]

    def _load_normalization_stats(self) -> dict[str, dict[str, float]]:
        """Load normalization statistics for each band.

        Returns:
            dictionary containing the normalization statistics
        """
        with open(
            os.path.join(
                self.root, self.filenames[self.subset], self.band_stats_filename
            )
        ) as f:
            band_stats = json.load(f)

        return cast(dict[str, dict[str, float]], band_stats)

    def _load_tile_info(self) -> dict[str, dict[str, str]]:
        """Load tile information.

        Returns:
            dictionary containing tile information
        """
        with open(
            os.path.join(
                self.root, self.filenames[self.subset], self.tile_info_filename
            )
        ) as f:
            tile_info = json.load(f)

        return cast(dict[str, dict[str, str]], tile_info)

    def _validate_modalities(self, modalities: Sequence[str]) -> None:
        """Validate list of modalities.

        Args:
            modalities: user-provided sequence of modalities to load

        Raises:
            AssertionError: if ``modalities`` is not a sequence
            ValueError: if an invalid modality name is provided
        """
        # validate modalities
        assert isinstance(modalities, Sequence), "'modalities' must be a sequence"
        if not set(modalities) <= set(self.all_modalities):
            raise ValueError(
                f'{set(modalities) - set(self.all_modalities)} is an invalid modality.'
            )

    def _validate_modality_bands(self, modality_bands: dict[str, list[str]]) -> None:
        """Validate modality bands.

        Args:
            modality_bands: user-provided dictionary of modality bands

        Raises:
            AssertionError: if ``modality_bands`` is not a dictionary
            ValueError: if an invalid modality name is provided
            ValueError: if modality bands are invalid
        """
        assert isinstance(modality_bands, dict), "'modality_bands' must be a dictionary"
        # validate modality bands
        for key, vals in modality_bands.items():
            # check that the modality name is also specified in modalities
            if key not in self.modalities:
                raise ValueError(f"'{key}' is an invalid modality name.")
            for val in vals:
                if val not in self.all_modality_bands[key]:
                    raise ValueError(
                        f"'{val}' is an invalid band name for modality '{key}'."
                    )
[docs] def __getitem__(self, index: int) -> dict[str, Any]: """Return a sample from the dataset. Normalization is applied to the data with chosen ``normalization_mode``. In addition to the modalities, the sample contains the following raw metadata: * lat: latitude * lon: longitude * date: date * crs: coordinate reference system * tile_id: tile identifier Args: index: index to return Returns: dictionary containing the modalities and metadata of the sample """ ds_index = self.indices[index] # expose sample retrieval to separate function to allow for different index sampling strategies # in subclasses sample = self._retrieve_sample(ds_index) if self.transforms is not None: sample = self.transforms(sample) return sample
[docs] def get_sample_specific_band_names( self, tile_info: dict[str, Any] ) -> dict[str, list[str]]: """Retrieve the sample specific band names. Args: tile_info: tile information for a sample Returns: dictionary containing the specific band names for each modality """ date_str = tile_info['S2_DATE'] date_obj = datetime.strptime(date_str, '%Y-%m-%d') curr_month_str = date_obj.strftime('%Y%m') # set to first day of month and subtract one day to get previous month prev_month_obj = date_obj.replace(day=1) - timedelta(days=1) prev_month_str = prev_month_obj.strftime('%Y%m') specific_modality_bands = {} for modality, bands in self.modality_bands.items(): if modality == 'era5': # replace date with the 'prev' and 'curr' strings for generality bands = [band.replace(prev_month_str, 'prev') for band in bands] bands = [band.replace(curr_month_str, 'curr') for band in bands] specific_modality_bands[modality] = bands return specific_modality_bands
[docs] def get_intersection_dict(self, tile_info: dict[str, Any]) -> dict[str, list[str]]: """Get intersection of requested and available bands. Args: tile_info: tile information for a sample Returns: Dictionary with intersected keys and lists. """ sample_specific_band_names = self.get_sample_specific_band_names(tile_info) # used the chosen modality bands to get the intersection with available bands intersection_dict = {} for modality in self.all_modalities: if modality in sample_specific_band_names: intersected_list = [ band for band in self.all_modality_bands[modality] if band in sample_specific_band_names[modality] ] if intersected_list: intersection_dict[modality] = intersected_list return intersection_dict
    def _retrieve_sample(self, ds_index: int) -> dict[str, Any]:
        """Retrieve a sample from the dataset.

        Args:
            ds_index: index inside the hdf5 dataset file

        Returns:
            dictionary containing the modalities and metadata of the sample
        """
        h5py = lazy_import('h5py')
        sample: dict[str, Any] = {}
        with h5py.File(
            os.path.join(self.root, self.filenames[self.subset], self.dataset_filename),
            'r',
        ) as f:
            # the tile name stored in 'metadata' keys the tile_info JSON
            name = f['metadata'][ds_index][0].decode('utf-8')
            tile_info: dict[str, Any] = self.tile_info[name]
            # need to find the intersection of requested and available bands
            intersection_dict = self.get_intersection_dict(tile_info)
            for modality, bands in intersection_dict.items():
                # ascending and descending orbits share one 'sentinel1' dataset;
                # band selection later disambiguates via an index offset
                if 'sentinel1' in modality:
                    data = f['sentinel1'][ds_index][:]
                else:
                    data = f[modality][ds_index][:]
                tensor = self._preprocess_modality(data, modality, tile_info, bands)
                modality_name = self.modality_category_name.get(modality, '') + modality
                sample[modality_name] = tensor

        # add the sensor and bands actually available
        sample['avail_bands'] = intersection_dict

        # add additional metadata to the sample
        sample['lat'] = tile_info['lat']
        sample['lon'] = tile_info['lon']
        sample['date'] = tile_info['S2_DATE']
        sample['crs'] = tile_info['CRS']
        sample['tile_id'] = name

        return sample

    def _select_indices_for_modality(
        self, modality: str, bands: list[str]
    ) -> list[int]:
        """Select bands for a modality.

        Args:
            modality: modality name
            bands: bands available for the modality

        Returns:
            list of band indices
        """
        # need to handle sentinel1 descending separately, because ascending
        # and descending are stored under the same modality; the descending
        # channels follow the 4 ascending ones, hence the +4 offset
        if modality == 'sentinel1_desc':
            indices = [
                self.all_modality_bands['sentinel1_desc'].index(band) + 4
                for band in bands
            ]
        # the modality is called sentinel2 but has different bands stats for l1c and l2a
        # but common indices
        elif modality in ['sentinel2_l1c', 'sentinel2_l2a']:
            indices = [
                self.all_modality_bands['sentinel2'].index(band) for band in bands
            ]
        else:
            indices = [self.all_modality_bands[modality].index(band) for band in bands]

        return indices

    def _preprocess_modality(
        self,
        data: 'np.typing.NDArray[Any]',
        modality: str,
        tile_info: dict[str, Any],
        bands: list[str],
    ) -> Tensor:
        """Preprocess a single modality.

        Selects the requested bands, remaps categorical labels to 0-indexed
        classes, normalizes continuous data, and converts to a torch Tensor.

        Args:
            data: data to process
            modality: modality name
            tile_info: tile information
            bands: available bands for the modality

        Returns:
            processed data
        """
        # band selection for modality
        indices = self._select_indices_for_modality(modality, bands)
        data = data[indices, ...]

        # See https://github.com/vishalned/MMEarth-train/blob/8d6114e8e3ccb5ca5d98858e742dac24350b64fd/mmearth_dataset.py#L69
        if modality == 'dynamic_world':
            # first replace 0 with nan then assign new labels to have 0-index classes
            data = np.where(data == self.no_data_vals[modality], np.nan, data)
            old_values = [1, 2, 3, 4, 5, 6, 7, 8, 9, np.nan]
            new_values = [0, 1, 2, 3, 4, 5, 6, 7, 8, np.nan]
            # ascending old -> old-1 remap is collision-free: a value produced
            # by an earlier step never equals a later old value.
            # NOTE(review): the np.nan -> np.nan entry is a no-op because
            # nan != nan in numpy comparisons
            for old, new in zip(old_values, new_values):
                data = np.where(data == old, new, data)
            # need to replace nan with a no-data value and get long tensor
            # maybe also 255 like esa_worldcover
            tensor = torch.from_numpy(data)
        elif modality == 'esa_worldcover':
            # remap the sparse ESA class codes (10, 20, ..., 100) to 0-10
            old_values = [10, 20, 30, 40, 50, 60, 70, 80, 90, 95, 100, 255]
            new_values = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 255]
            for old, new in zip(old_values, new_values):
                data = np.where(data == old, new, data)
            # currently no-data value is still 255
            tensor = torch.from_numpy(data).long()
        elif modality in [
            'aster',
            'canopy_height_eth',
            'sentinel1_asc',
            'sentinel1_desc',
            'sentinel2',
            'era5',
            'lat',
            'lon',
            'month',
        ]:
            # continuous modalities: normalize then mark no-data as nan
            data = data.astype(np.float32)
            # See https://github.com/vishalned/MMEarth-train/blob/8d6114e8e3ccb5ca5d98858e742dac24350b64fd/mmearth_dataset.py#L88
            # the modality is called sentinel2 but has different bands stats for l1c and l2a
            if modality == 'sentinel2':
                modality_ = (
                    'sentinel2_l2a' if tile_info['S2_type'] == 'l2a' else 'sentinel2_l1c'
                )
            else:
                modality_ = modality
            data = self._normalize_modality(data, modality_, bands)
            # NOTE(review): the no-data comparison runs AFTER normalization, so
            # it only reliably matches modalities whose sentinel is +/-inf
            # (inf survives the affine transform); finite sentinels such as
            # sentinel2's 0 will not match post-normalization — confirm intent
            # against the upstream implementation linked above
            data = np.where(data == self.no_data_vals[modality], np.nan, data)
            tensor = torch.from_numpy(data).float()
        elif modality in ['biome', 'eco_region']:
            data = data.astype(np.int32)
            # no data value also 255 for biome and 65535 for eco_region
            tensor = torch.from_numpy(data).long()
        elif modality in [
            'sentinel2_cloudmask',
            'sentinel2_cloudprod',
            'sentinel2_scl',
        ]:
            tensor = torch.from_numpy(data.astype(np.int32)).long()

        # TODO: tensor might still contain nans, how to handle this?
        return tensor

    def _normalize_modality(
        self, data: 'np.typing.NDArray[Any]', modality: str, bands: list[str]
    ) -> 'np.typing.NDArray[np.float64]':
        """Normalize a single modality.

        Args:
            data: data to normalize
            modality: modality name
            bands: available bands for the modality

        Returns:
            normalized data
        """
        indices = self._select_indices_for_modality(modality, bands)
        # ascending and descending orbits share the 'sentinel1' statistics entry
        if 'sentinel1' in modality:
            modality = 'sentinel1'

        if self.normalization_mode == 'z-score':
            mean = np.array(self.band_stats[modality]['mean'])[indices, ...]
            std = np.array(self.band_stats[modality]['std'])[indices, ...]
            # broadcast per-band stats over the spatial dims for image data
            if data.ndim == 3:
                data = (data - mean[:, None, None]) / std[:, None, None]
            else:
                data = (data - mean) / std
        elif self.normalization_mode == 'min-max':
            min_val = np.array(self.band_stats[modality]['min'])[indices, ...]
            max_val = np.array(self.band_stats[modality]['max'])[indices, ...]
            if data.ndim == 3:
                data = (data - min_val[:, None, None]) / (
                    max_val[:, None, None] - min_val[:, None, None]
                )
            else:
                data = (data - min_val) / (max_val - min_val)

        return data
[docs] def __len__(self) -> int: """Return the length of the dataset. Returns: length of the dataset """ return len(self.indices)
    def plot(
        self,
        sample: dict[str, Any],
        show_titles: bool = True,
        suptitle: str | None = None,
    ) -> Figure:
        """Plot a sample from the dataset as shown in fig. 2 from https://arxiv.org/pdf/2405.02771.

        Args:
            sample: A sample returned by :meth:`__getitem__`.
            show_titles: Flag indicating whether to show titles above each panel.
            suptitle: Optional string to use as a suptitle.

        Returns:
            A matplotlib Figure with the rendered sample.
        """
        # RGB colors for the categorical masks, keyed by the REMAPPED
        # (0-indexed) class values produced in _preprocess_modality
        color_map = {
            'esa_worldcover': {
                0: [0, 100, 0],  # Tree cover
                1: [255, 187, 34],  # Shrubland
                2: [255, 255, 76],  # Grassland
                3: [240, 150, 255],  # Cropland
                4: [250, 0, 0],  # Built-up
                5: [180, 180, 180],  # Bare/sparse vegetation
                6: [240, 240, 240],  # Snow and Ice
                7: [0, 100, 200],  # Permanent water bodies
                8: [0, 150, 160],  # Herbaceous wetland
                9: [0, 207, 117],  # Mangroves
                10: [250, 230, 160],  # Moss and lichen
                255: [0, 0, 0],  # No-data value
            },
            'dynamic_world': {
                0: [65, 155, 223],  # #419BDF - Water
                1: [57, 125, 73],  # #397D49 - Trees
                2: [136, 176, 83],  # #88B053 - Grass
                3: [122, 135, 198],  # #7A87C6 - Flooded vegetation
                4: [228, 150, 53],  # #E49635 - Crops
                5: [223, 195, 90],  # #DFC35A - Shrub & Scrub
                6: [196, 40, 27],  # #C4281B - Built Area
                7: [165, 155, 143],  # #A59B8F - Bare ground
                8: [179, 159, 225],  # #B39FE1 - Snow & Ice
            },
        }

        images = []
        titles = []
        # fixed six-panel layout; assumes all of these modalities were loaded
        keys_to_plot = [
            'image_sentinel2',
            'image_sentinel1_asc',
            'image_aster',
            'mask_esa_worldcover',
            'mask_dynamic_world',
            'image_canopy_height_eth',
        ]
        avail_bands_dict = dict(sample['avail_bands'])
        for key in keys_to_plot:
            val = sample[key]
            # strip the 'image_'/'mask_' prefix to recover the modality name
            modalities_name = key.split('_', 1)[1]
            match modalities_name:
                case 'sentinel2':
                    # channel indices 3, 2, 1 select bands B4, B3, B2 (RGB)
                    norm_img = percentile_normalization(val[[3, 2, 1]].numpy())
                    images.append(rearrange(norm_img, 'c h w -> h w c'))
                    titles.append('Sentinel-2 RGB')
                case 'esa_worldcover':
                    # paint each class value with its color-map entry
                    tensor_np = val.squeeze().numpy()
                    rgb_image = np.zeros(
                        (tensor_np.shape[0], tensor_np.shape[1], 3), dtype=np.uint8
                    )
                    for value, color in color_map[modalities_name].items():
                        mask = tensor_np == value
                        rgb_image[mask] = color
                    images.append(rgb_image)
                    titles.append(modalities_name.replace('_', ' ').title())
                case 'dynamic_world':
                    tensor_np = val.squeeze().numpy()
                    rgb_image = np.zeros(
                        (tensor_np.shape[0], tensor_np.shape[1], 3), dtype=np.uint8
                    )
                    for value, color in color_map[modalities_name].items():
                        mask = tensor_np == value
                        rgb_image[mask] = color
                    images.append(rgb_image)
                    titles.append(modalities_name.replace('_', ' ').title())
                case _:
                    # continuous modalities: show only the first available band
                    band_val = val[0].numpy()
                    norm_img = percentile_normalization(band_val)
                    images.append(norm_img)
                    modalities_name = key.split('_', 1)[1]
                    band_name = avail_bands_dict[modalities_name][0]
                    titles.append(
                        (modalities_name.replace('_', ' ').title()) + ' ' + band_name
                    )

        fig, ax = plt.subplots(1, 6, figsize=(12, 4))
        for i, (image, title) in enumerate(zip(images, titles)):
            ax[i].imshow(image)
            ax[i].axis('off')
            if show_titles:
                # wrap titles longer than two words onto a second line
                title_words = title.split(' ')
                title_word_len = len(title_words)
                if title_word_len > 2:
                    title = (
                        str.join(' ', title_words[:2])
                        + '\n'
                        + str.join(' ', title_words[2:])
                    )
                ax[i].set_title(title)

        if suptitle is not None:
            plt.suptitle(suptitle)
        plt.tight_layout()
        return fig

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources