Shortcuts

Source code for torchgeo.datasets.copernicus.lcz_s2

# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.

"""Copernicus-Bench LCZ-S2 dataset."""

import os
from collections.abc import Callable, Sequence
from typing import ClassVar, Literal

import torch
from einops import rearrange
from torch import Tensor

from ..errors import DatasetNotFoundError
from ..utils import Path, download_url, lazy_import
from .base import CopernicusBenchBase


class CopernicusBenchLCZS2(CopernicusBenchBase):
    """Copernicus-Bench LCZ-S2 dataset.

    LCZ-S2 is a multi-class scene classification dataset derived from
    So2Sat-LCZ42, a large-scale local climate zone classification dataset.

    Each split is stored in a single HDF5 file (``lcz_{split}.h5``) from which
    this class reads a ``sen2`` image dataset and a ``label`` dataset.

    If you use this dataset in your research, please cite the following papers:

    * https://arxiv.org/abs/2503.11849
    * https://doi.org/10.1109/MGRS.2020.2964708

    .. note::
       This dataset requires the following additional library to be installed:

       * `h5py <https://pypi.org/project/h5py/>`_ to load the dataset.

    .. versionadded:: 0.7
    """

    # Download location of the per-split HDF5 file; '{}' is filled with the split.
    url = 'https://hf.co/datasets/wangyi111/Copernicus-Bench/resolve/9d252acd3aa0e3da3128e05c6f028647f0e48e5f/l3_lcz_s2/lcz_{}.h5'

    # MD5 digest of each split's file, only consulted when ``checksum`` is set.
    md5s: ClassVar[dict[str, str]] = {
        'train': 'e0b10cdb7f12e053cda8dd3ff12dbd9e',
        'val': 'be3b503dba5a1405ec6d5a770c2bee33',
        'test': '4e95788c72a421d636f6f8dc7623d116',
    }

    # Local file name of a split inside ``root``; '{}' is filled with the split.
    filename = 'lcz_{}.h5'

    # Band names; a band's position here is the channel index used when
    # slicing the last axis of a ``sen2`` image.
    all_bands = ('B02', 'B03', 'B04', 'B05', 'B06', 'B07', 'B08', 'B8A', 'B11', 'B12')
    rgb_bands = ('B04', 'B03', 'B02')

    # Human-readable class names.
    # NOTE(review): presumably ordered to match the label index returned by
    # ``__getitem__`` -- confirm against the upstream data.
    classes = (
        'Compact high rise',
        'Compact mid rise',
        'Compact low rise',
        'Open high rise',
        'Open mid rise',
        'Open low rise',
        'Lightweight low rise',
        'Large low rise',
        'Sparsely built',
        'Heavy industry',
        'Dense trees',
        'Scattered trees',
        'Bush, scrub',
        'Low plants',
        'Bare rock or paved',
        'Bare soil or sand',
        'Water',
    )

    def __init__(
        self,
        root: Path = 'data',
        split: Literal['train', 'val', 'test'] = 'train',
        bands: Sequence[str] | None = None,
        transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
        download: bool = False,
        checksum: bool = False,
    ) -> None:
        """Initialize a new CopernicusBenchLCZS2 instance.

        Args:
            root: Root directory where dataset can be found.
            split: One of 'train', 'val', or 'test'.
            bands: Sequence of band names to load (defaults to all bands).
            transforms: A function/transform that takes input sample and its target
                as entry and returns a transformed version.
            download: If True, download dataset and store it in the root directory.
            checksum: If True, check the MD5 of the downloaded files (may be slow).

        Raises:
            DatasetNotFoundError: If dataset is not found and *download* is False.
        """
        # Resolve the optional dependency before touching any state.
        h5py = lazy_import('h5py')

        self.root = root
        self.split = split

        # A missing (or empty) band selection falls back to every band.
        self.bands = bands if bands else self.all_bands
        # Translate the requested band names into channel positions.
        self.band_indices = [self.all_bands.index(band) for band in self.bands]

        self.transforms = transforms
        self.download = download
        self.checksum = checksum

        # Make sure the split file is on disk (downloading it if permitted).
        self._verify()

        self.filepath = os.path.join(root, self.filename.format(split))

        # The number of samples is the leading dimension of the label dataset.
        with h5py.File(self.filepath, 'r') as h5_file:
            label_shape = h5_file['label'].shape
            self.length: int = label_shape[0]

    def __len__(self) -> int:
        """Return the number of samples in the selected split.

        Returns:
            Length of the dataset.
        """
        return self.length

    def __getitem__(self, index: int) -> dict[str, Tensor]:
        """Return the sample stored at the given index.

        Args:
            index: Index to return.

        Returns:
            A dict with the selected bands as a channels-first tensor under
            'image' and the class index under 'label', after applying
            *transforms* if any were given.
        """
        h5py = lazy_import('h5py')

        # The file is reopened on every access; both reads yield in-memory
        # NumPy data, so all further processing happens after it is closed.
        with h5py.File(self.filepath, 'r') as h5_file:
            pixels = h5_file['sen2'][index]
            label_vector = h5_file['label'][index]

        # Keep only the requested channels (last axis), then go channels-first.
        selected = pixels[:, :, self.band_indices]
        image = rearrange(selected, 'h w c -> c h w')

        # The class index is the position of the largest label entry
        # (presumably a one-hot vector -- confirm against the data).
        label = label_vector.argmax()

        sample = {'image': torch.from_numpy(image), 'label': torch.tensor(label)}

        if self.transforms is None:
            return sample
        return self.transforms(sample)

    def _verify(self) -> None:
        """Verify that the split file exists, downloading it when allowed.

        Raises:
            DatasetNotFoundError: If the file is missing and *download* is False.
        """
        split_file = os.path.join(self.root, self.filename.format(self.split))

        # Nothing to do unless the file is missing.
        if not os.path.exists(split_file):
            if self.download:
                # The file is a plain HDF5 download; no extraction step follows.
                self._download()
            else:
                raise DatasetNotFoundError(self)

    def _download(self) -> None:
        """Download the file of the selected split into the root directory."""
        # Only hand a digest to the downloader when checksumming was requested.
        expected_md5 = None
        if self.checksum:
            expected_md5 = self.md5s[self.split]

        download_url(self.url.format(self.split), self.root, md5=expected_md5)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources