Shortcuts

Source code for torchgeo.datasets.satlas

# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.

"""SatlasPretrain dataset."""

import os
from collections.abc import Callable, Iterable
from typing import ClassVar, TypedDict

import numpy as np
import pandas as pd
import torch
from einops import rearrange
from matplotlib import pyplot as plt
from matplotlib.figure import Figure
from PIL import Image
from torch import Tensor

from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
from .utils import Path, check_integrity, extract_archive, which


class _Task(TypedDict, total=False):
    """Schema of a single entry in :data:`TASKS`.

    ``total=False`` makes every key optional, because not every task defines
    every field (e.g. some entries have no ``colors``).
    """

    # Kind of task, e.g. 'segment', 'bin_segment', 'regress', 'detect'
    type: str
    # Class names of the task
    categories: list[str]
    # One [R, G, B] triple per entry of ``categories``
    colors: list[list[int]]
    # Flag taken over from the reference implementation
    BackgroundInvalid: bool


# Task definitions keyed by label product name. Source:
# https://github.com/allenai/satlas/blob/main/satlas/model/dataset.py
#
# Every entry has a 'type' and, optionally:
#   * 'categories': class names of the task
#   * 'colors': one [R, G, B] triple per category, in the same order as
#     'categories'; SatlasPretrain.plot indexes this list to colorize masks
#   * 'BackgroundInvalid': flag taken over from the reference implementation
#     (presumably marks the background class as invalid -- TODO confirm upstream)
#
# NOTE: several categories share the same color (e.g. all road_* classes are
# white), so colors do not uniquely identify a category.
TASKS: dict[str, _Task] = {
    'polyline_bin_segment': {
        'type': 'bin_segment',
        'categories': [
            'airport_runway',
            'airport_taxiway',
            'raceway',
            'road',
            'railway',
            'river',
        ],
        'colors': [
            [255, 255, 255],  # (white) airport_runway
            [192, 192, 192],  # (light grey) airport_taxiway
            [160, 82, 45],  # (sienna) raceway
            [255, 255, 255],  # (white) road
            [144, 238, 144],  # (light green) railway
            [0, 0, 255],  # (blue) river
        ],
    },
    'bin_segment': {
        'type': 'bin_segment',
        'categories': [
            'aquafarm',
            'lock',
            'dam',
            'solar_farm',
            'power_plant',
            'gas_station',
            'park',
            'parking_garage',
            'parking_lot',
            'landfill',
            'quarry',
            'stadium',
            'airport',
            'airport_runway',
            'airport_taxiway',
            'airport_apron',
            'airport_hangar',
            'airstrip',
            'airport_terminal',
            'ski_resort',
            'theme_park',
            'storage_tank',
            'silo',
            'track',
            'raceway',
            'wastewater_plant',
            'road',
            'railway',
            'river',
            'water_park',
            'pier',
            'water_tower',
            'street_lamp',
            'traffic_signals',
            'power_tower',
            'power_substation',
            'building',
            'bridge',
            'road_motorway',
            'road_trunk',
            'road_primary',
            'road_secondary',
            'road_tertiary',
            'road_residential',
            'road_service',
            'road_track',
            'road_pedestrian',
        ],
        'colors': [
            [32, 178, 170],  # (light sea green) aquafarm
            [0, 255, 255],  # (cyan) lock
            [173, 216, 230],  # (light blue) dam
            [255, 0, 255],  # (magenta) solar_farm
            [255, 165, 0],  # (orange) power_plant
            [128, 128, 0],  # (olive) gas_station
            [0, 255, 0],  # (green) park
            [47, 79, 79],  # (dark slate gray) parking_garage
            [128, 0, 0],  # (maroon) parking_lot
            [165, 42, 42],  # (brown) landfill
            [128, 128, 128],  # (grey) quarry
            [255, 215, 0],  # (gold) stadium
            [255, 105, 180],  # (pink) airport
            [255, 255, 255],  # (white) airport_runway
            [192, 192, 192],  # (light grey) airport_taxiway
            [128, 0, 128],  # (purple) airport_apron
            [0, 128, 0],  # (dark green) airport_hangar
            [248, 248, 255],  # (ghost white) airstrip
            [240, 230, 140],  # (khaki) airport_terminal
            [192, 192, 192],  # (silver) ski_resort
            [0, 96, 0],  # (dark green) theme_park
            [95, 158, 160],  # (cadet blue) storage_tank
            [205, 133, 63],  # (peru) silo
            [154, 205, 50],  # (yellow green) track
            [160, 82, 45],  # (sienna) raceway
            [218, 112, 214],  # (orchid) wastewater_plant
            [255, 255, 255],  # (white) road
            [144, 238, 144],  # (light green) railway
            [0, 0, 255],  # (blue) river
            [255, 240, 245],  # (lavender blush) water_park
            [65, 105, 225],  # (royal blue) pier
            [238, 130, 238],  # (violet) water_tower
            [75, 0, 130],  # (indigo) street_lamp
            [233, 150, 122],  # (dark salmon) traffic_signals
            [255, 255, 0],  # (yellow) power_tower
            [255, 255, 0],  # (yellow) power_substation
            [255, 0, 0],  # (red) building
            [64, 64, 64],  # (dark grey) bridge
            [255, 255, 255],  # (white) road_motorway
            [255, 255, 255],  # (white) road_trunk
            [255, 255, 255],  # (white) road_primary
            [255, 255, 255],  # (white) road_secondary
            [255, 255, 255],  # (white) road_tertiary
            [255, 255, 255],  # (white) road_residential
            [255, 255, 255],  # (white) road_service
            [255, 255, 255],  # (white) road_track
            [255, 255, 255],  # (white) road_pedestrian
        ],
    },
    'land_cover': {
        'type': 'segment',
        'BackgroundInvalid': True,
        'categories': [
            'background',
            'water',
            'developed',
            'tree',
            'shrub',
            'grass',
            'crop',
            'bare',
            'snow',
            'wetland',
            'mangroves',
            'moss',
        ],
        'colors': [
            [0, 0, 0],  # (black) background
            [0, 0, 255],  # (blue) water
            [255, 0, 0],  # (red) developed
            [0, 192, 0],  # (dark green) tree
            [200, 170, 120],  # (brown) shrub
            [0, 255, 0],  # (green) grass
            [255, 255, 0],  # (yellow) crop
            [128, 128, 128],  # (grey) bare
            [255, 255, 255],  # (white) snow
            [0, 255, 255],  # (cyan) wetland
            [255, 0, 255],  # (magenta) mangroves
            [128, 0, 128],  # (purple) moss
        ],
    },
    'tree_cover': {'type': 'regress', 'BackgroundInvalid': True},
    'crop_type': {
        'type': 'segment',
        'BackgroundInvalid': True,
        'categories': [
            'invalid',
            'rice',
            'grape',
            'corn',
            'sugarcane',
            'tea',
            'hop',
            'wheat',
            'soy',
            'barley',
            'oats',
            'rye',
            'cassava',
            'potato',
            'sunflower',
            'asparagus',
            'coffee',
        ],
        'colors': [
            [0, 0, 0],  # (black) invalid
            [0, 0, 255],  # (blue) rice
            [255, 0, 0],  # (red) grape
            [255, 255, 0],  # (yellow) corn
            [0, 255, 0],  # (green) sugarcane
            [128, 0, 128],  # (purple) tea
            [255, 0, 255],  # (magenta) hop
            [0, 128, 0],  # (dark green) wheat
            [255, 255, 255],  # (white) soy
            [128, 128, 128],  # (grey) barley
            [165, 42, 42],  # (brown) oats
            [0, 255, 255],  # (cyan) rye
            [128, 0, 0],  # (maroon) cassava
            [173, 216, 230],  # (light blue) potato
            [128, 128, 0],  # (olive) sunflower
            [0, 128, 0],  # (dark green) asparagus
            [92, 64, 51],  # (dark brown) coffee
        ],
    },
    'point': {
        'type': 'detect',
        'categories': [
            'background',
            'wind_turbine',
            'lighthouse',
            'mineshaft',
            'aerialway_pylon',
            'helipad',
            'fountain',
            'toll_booth',
            'chimney',
            'communications_tower',
            'flagpole',
            'petroleum_well',
            'water_tower',
            'offshore_wind_turbine',
            'offshore_platform',
            'power_tower',
        ],
        'colors': [
            [0, 0, 0],  # (black) background
            [0, 255, 255],  # (cyan) wind_turbine
            [0, 255, 0],  # (green) lighthouse
            [255, 255, 0],  # (yellow) mineshaft
            [0, 0, 255],  # (blue) aerialway_pylon
            [173, 216, 230],  # (light blue) helipad
            [128, 0, 128],  # (purple) fountain
            [255, 255, 255],  # (white) toll_booth
            [0, 128, 0],  # (dark green) chimney
            [128, 128, 128],  # (grey) communications_tower
            [165, 42, 42],  # (brown) flagpole
            [128, 0, 0],  # (maroon) petroleum_well
            [255, 165, 0],  # (orange) water_tower
            [255, 255, 0],  # (yellow) offshore_wind_turbine
            [255, 0, 0],  # (red) offshore_platform
            [255, 0, 255],  # (magenta) power_tower
        ],
    },
    'rooftop_solar_panel': {
        'type': 'detect',
        'categories': ['background', 'rooftop_solar_panel'],
        'colors': [
            [0, 0, 0],  # (black) background
            [255, 255, 0],  # (yellow) rooftop_solar_panel
        ],
    },
    'building': {
        'type': 'instance',
        'categories': ['background', 'ms_building'],
        'colors': [
            [0, 0, 0],  # (black) background
            [255, 255, 0],  # (yellow) ms_building
        ],
    },
    'polygon': {
        'type': 'instance',
        'categories': [
            'background',
            'aquafarm',
            'lock',
            'dam',
            'solar_farm',
            'power_plant',
            'gas_station',
            'park',
            'parking_garage',
            'parking_lot',
            'landfill',
            'quarry',
            'stadium',
            'airport',
            'airport_apron',
            'airport_hangar',
            'airport_terminal',
            'ski_resort',
            'theme_park',
            'storage_tank',
            'silo',
            'track',
            'wastewater_plant',
            'power_substation',
            'pier',
            'crop',
            'water_park',
        ],
        'colors': [
            [0, 0, 0],  # (black) background
            [255, 255, 0],  # (yellow) aquafarm
            [0, 255, 255],  # (cyan) lock
            [0, 255, 0],  # (green) dam
            [0, 0, 255],  # (blue) solar_farm
            [255, 0, 0],  # (red) power_plant
            [128, 0, 128],  # (purple) gas_station
            [255, 255, 255],  # (white) park
            [0, 128, 0],  # (dark green) parking_garage
            [128, 128, 128],  # (grey) parking_lot
            [165, 42, 42],  # (brown) landfill
            [128, 0, 0],  # (maroon) quarry
            [255, 165, 0],  # (orange) stadium
            [255, 105, 180],  # (pink) airport
            [192, 192, 192],  # (silver) airport_apron
            [173, 216, 230],  # (light blue) airport_hangar
            [32, 178, 170],  # (light sea green) airport_terminal
            [255, 0, 255],  # (magenta) ski_resort
            [128, 128, 0],  # (olive) theme_park
            [47, 79, 79],  # (dark slate gray) storage_tank
            [255, 215, 0],  # (gold) silo
            [192, 192, 192],  # (light grey) track
            [240, 230, 140],  # (khaki) wastewater_plant
            [154, 205, 50],  # (yellow green) power_substation
            [255, 165, 0],  # (orange) pier
            [0, 192, 0],  # (middle green) crop
            [0, 192, 0],  # (middle green) water_park
        ],
    },
    'wildfire': {
        'type': 'bin_segment',
        'categories': ['fire_retardant', 'burned'],
        'colors': [
            [255, 0, 0],  # (red) fire_retardant
            [128, 128, 128],  # (grey) burned
        ],
    },
    'smoke': {'type': 'classification', 'categories': ['no', 'partial', 'yes']},
    'snow': {'type': 'classification', 'categories': ['no', 'partial', 'yes']},
    'dem': {'type': 'regress', 'BackgroundInvalid': True},
    'airplane': {
        'type': 'detect',
        'categories': ['background', 'airplane'],
        'colors': [
            [0, 0, 0],  # (black) background
            [255, 0, 0],  # (red) airplane
        ],
    },
    'vessel': {
        'type': 'detect',
        'categories': ['background', 'vessel'],
        'colors': [
            [0, 0, 0],  # (black) background
            [255, 0, 0],  # (red) vessel
        ],
    },
    'water_event': {
        'type': 'segment',
        'BackgroundInvalid': True,
        'categories': ['invalid', 'background', 'water_event'],
        'colors': [
            [0, 0, 0],  # (black) invalid
            [0, 255, 0],  # (green) background
            [0, 0, 255],  # (blue) water_event
        ],
    },
    'park_sport': {
        'type': 'classification',
        'categories': [
            'american_football',
            'badminton',
            'baseball',
            'basketball',
            'cricket',
            'rugby',
            'soccer',
            'tennis',
            'volleyball',
        ],
    },
    'park_type': {
        'type': 'classification',
        'categories': ['park', 'pitch', 'golf_course', 'cemetery'],
    },
    'power_plant_type': {
        'type': 'classification',
        'categories': ['oil', 'nuclear', 'coal', 'gas'],
    },
    'quarry_resource': {
        'type': 'classification',
        'categories': ['sand', 'gravel', 'clay', 'coal', 'peat'],
    },
    'track_sport': {
        'type': 'classification',
        'categories': ['running', 'cycling', 'horse'],
    },
    'road_type': {
        'type': 'classification',
        'categories': [
            'motorway',
            'trunk',
            'primary',
            'secondary',
            'tertiary',
            'residential',
            'service',
            'track',
            'pedestrian',
        ],
    },
    'cloud': {
        'type': 'bin_segment',
        'categories': ['background', 'cloud', 'shadow'],
        'colors': [
            [0, 255, 0],  # (green) background, i.e. neither cloud nor shadow
            [255, 255, 255],  # (white) cloud
            [128, 128, 128],  # (grey) shadow
        ],
        'BackgroundInvalid': True,
    },
    'flood': {
        'type': 'bin_segment',
        'categories': ['background', 'water'],
        'colors': [
            [0, 255, 0],  # (green) background
            [0, 0, 255],  # (blue) water
        ],
        'BackgroundInvalid': True,
    },
}


[docs]class SatlasPretrain(NonGeoDataset): """SatlasPretrain dataset. `SatlasPretrain <https://satlas-pretrain.allen.ai/>`__ is a large-scale pre-training dataset for tasks that involve understanding satellite images. Regularly-updated satellite data is publicly available for much of the Earth through sources such as Sentinel-2 and NAIP, and can inform numerous applications from tackling illegal deforestation to monitoring marine infrastructure. However, developing automatic computer vision systems to parse these images requires a huge amount of manual labeling of training data. By combining over 30 TB of satellite images with 137 label categories, SatlasPretrain serves as an effective pre-training dataset that greatly reduces the effort needed to develop robust models for downstream satellite image applications. Reference implementation: * https://github.com/allenai/satlas/blob/main/satlas/model/dataset.py If you use this dataset in your research, please cite the following paper: * https://doi.org/10.48550/arXiv.2211.15660 .. versionadded:: 0.7 .. note:: This dataset requires the following additional library to be installed: * `AWS CLI <https://aws.amazon.com/cli/>`_: to download the dataset from AWS. 
""" # https://github.com/allenai/satlas/blob/main/satlaspretrain_urls.txt url = 's3://ai2-public-datasets/satlas/' tarballs: ClassVar[dict[str, tuple[str, ...]]] = { 'landsat': ('satlas-dataset-v1-landsat.tar',), 'naip': ( 'satlas-dataset-v1-naip-2011.tar', 'satlas-dataset-v1-naip-2012.tar', 'satlas-dataset-v1-naip-2013.tar', 'satlas-dataset-v1-naip-2014.tar', 'satlas-dataset-v1-naip-2015.tar', 'satlas-dataset-v1-naip-2016.tar', 'satlas-dataset-v1-naip-2017.tar', 'satlas-dataset-v1-naip-2018.tar', 'satlas-dataset-v1-naip-2019.tar', 'satlas-dataset-v1-naip-2020.tar', ), 'sentinel1': ('satlas-dataset-v1-sentinel1-new.tar',), 'sentinel2': ( 'satlas-dataset-v1-sentinel2-a.tar', 'satlas-dataset-v1-sentinel2-b.tar', ), 'static': ('satlas-dataset-v1-labels-static.tar',), 'dynamic': ('satlas-dataset-v1-labels-dynamic.tar',), 'metadata': ('satlas-dataset-v1-metadata.tar',), } md5s: ClassVar[dict[str, tuple[str, ...]]] = { 'landsat': ('89ea5e8974826c071908392827780a06',), 'naip': ( '523736842994861054f04b97c4d90bfb', '636b9a3b08be0e40d098cb7b5e655b57', '69e2b1052b1d2d465322a24cf7207a16', '38999aea424d403ad60e1398443636aa', '97f4855072a8a406a4bfbe94c5f7311c', '9ba3c626b23e6d26749a323eaedc7c0a', 'e4aba3d198dedfe1524a9338e85794aa', '74191a36d841b0b9b5d5cbae9a92ad71', '55b110cc6f734bf88793306d49f1c415', '97fc8414334987c59593d574f112a77e', ), 'sentinel1': ('3d88a0a10df6ab0aa50db2ba4c475048',), 'sentinel2': ( '7e1c6a1e322807fb11df8c0c062545ca', '6636b8ecf2fff1d6723ecfef55a4876d', ), 'static': ('4e38c2573bc78cf1f0d7267e432cb42c',), 'dynamic': ('4503ae687948e7d2cb7ade0083f77a8a',), 'metadata': ('6b9ac5a4f9a1ee88a271d28f12854607',), } # NOTE: 'tci' is RGB (b04-b02), not BGR (b02-b04) bands: ClassVar[dict[str, tuple[str, ...]]] = { 'landsat': tuple(f'b{i}' for i in range(1, 12)), 'naip': ('tci', 'ir'), 'sentinel1': ('vh', 'vv'), 'sentinel2': ('tci', 'b05', 'b06', 'b07', 'b08', 'b11', 'b12'), } chip_size = 512
[docs] def __init__( self, root: Path = 'data', split: str = 'train_lowres', good_images: str = 'good_images_lowres_all', image_times: str = 'image_times', images: Iterable[str] = ('sentinel1', 'sentinel2', 'landsat'), labels: Iterable[str] = ('land_cover',), transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, ) -> None: """Initialize a new SatlasPretrain instance. Args: root: Root directory where dataset can be found. split: Metadata split to load. good_images: Metadata mapping between col/row and directory. image_times: Metadata mapping between directory and ISO time. images: List of image products. labels: List of label products. transforms: A function/transform that takes input sample and its target as entry and returns a transformed version. download: If True, download dataset and store it in the root directory. checksum: If True, check the MD5 of the downloaded files (may be slow). Raises: AssertionError: If *images* is invalid. DatasetNotFoundError: If dataset is not found and *download* is False. """ assert set(images) <= set(self.bands.keys()) self.root = root self.images = images self.labels = labels self.transforms = transforms self.download = download self.checksum = checksum self._verify() # Read metadata files self.split = pd.read_json( os.path.join(root, 'metadata', f'{split}.json'), typ='frame' ) self.good_images = pd.read_json( os.path.join(root, 'metadata', f'{good_images}.json'), typ='frame' ) self.image_times = pd.read_json( os.path.join(root, 'metadata', f'{image_times}.json'), typ='series' ) self.split.columns = ['col', 'row'] self.good_images.columns = ['col', 'row', 'directory'] self.good_images = self.good_images.groupby(['col', 'row'])
[docs] def __len__(self) -> int: """Return the number of locations in the dataset. Returns: Length of the dataset """ return len(self.split)
[docs] def __getitem__(self, index: int) -> dict[str, Tensor]: """Return an index within the dataset. Args: index: Index to return. Returns: Data and label at that index. """ col, row = self.split.iloc[index] directories = self.good_images.get_group((col, row))['directory'] sample: dict[str, Tensor] = {} for image in self.images: self._load_image(sample, image, col, row, directories) for label in self.labels: self._load_label(sample, label, col, row) if self.transforms is not None: sample = self.transforms(sample) return sample
def _load_image( self, sample: dict[str, Tensor], image: str, col: int, row: int, directories: pd.Series, ) -> None: """Load a single image. Args: sample: Dataset sample to populate. image: Image product. col: Web Mercator column. row: Web Mercator row. directories: Directories that may contain the image. """ # Find directories that match image product good_directories: list[str] = [] for directory in directories: path = os.path.join(self.root, image, directory) if os.path.isdir(path): good_directories.append(directory) # Choose a random timestamp idx = torch.randint(len(good_directories), (1,)) directory = good_directories[idx] time = self.image_times[directory].timestamp() sample[f'time_{image}'] = torch.tensor(time) # Load all bands resample = Image.Resampling.BILINEAR channels = [] for band in self.bands[image]: path = os.path.join(self.root, image, directory, band, f'{col}_{row}.png') with Image.open(path) as img: img = img.resize((self.chip_size, self.chip_size), resample=resample) array = np.atleast_3d(np.array(img, dtype=np.float32)) channels.append(torch.tensor(array)) raster = rearrange(torch.cat(channels, dim=-1), 'h w c -> c h w') sample[f'image_{image}'] = raster def _load_label( self, sample: dict[str, Tensor], label: str, col: int, row: int ) -> None: """Load a single label. Args: sample: Dataset sample to populate. label: Label product. col: Web Mercator column. row: Web Mercator row. 
""" path = os.path.join(self.root, 'static', f'{col}_{row}', f'{label}.png') if os.path.isfile(path): with Image.open(path) as img: raster = torch.tensor(np.array(img, dtype=np.int64)) else: raster = torch.zeros(self.chip_size, self.chip_size, dtype=torch.long) sample[f'mask_{label}'] = raster def _verify(self) -> None: """Verify the integrity of the dataset.""" products = [*self.images, 'metadata'] if self.labels: products.append('static') for product in products: # Check if the extracted directory already exists if os.path.isdir(os.path.join(self.root, product)): continue tarballs = self.tarballs[product] md5s = self.md5s[product] for tarball, md5 in zip(tarballs, md5s): path = os.path.join(self.root, tarball) # Check if the tarball has already been downloaded if os.path.isfile(path): extract_archive(path) continue # Check if the user requested to download the dataset if not self.download: raise DatasetNotFoundError(self) # Download and extract the tarball aws = which('aws') aws('s3', 'cp', self.url + tarball, self.root) check_integrity(path, md5 if self.checksum else None) extract_archive(path)
[docs] def plot( self, sample: dict[str, Tensor], show_titles: bool = True, suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. Args: sample: A sample returned by :meth:`__getitem__`. show_titles: Flag indicating whether to show titles above each panel. suptitle: Optional string to use as a suptitle. Returns: A matplotlib Figure with the rendered sample. """ images = [] titles = [] for key, value in sample.items(): match key.split('_', 1): case ['image', 'landsat']: images.append(rearrange(value[[3, 2, 1]], 'c h w -> h w c') / 255) titles.append('Landsat 8/9') case ['image', 'naip']: images.append(rearrange(value[:3], 'c h w -> h w c') / 255) titles.append('NAIP') case ['image', 'sentinel1']: images.extend([value[0] / 255, value[1] / 255]) titles.extend(['Sentinel-1 VH', 'Sentinel-1 VV']) case ['image', 'sentinel2']: images.append(rearrange(value[:3], 'c h w -> h w c') / 255) titles.append('Sentinel-2') case ['mask' | 'prediction', label]: cmap = torch.tensor(TASKS[label]['colors']) images.append(cmap[value]) titles.append(label.replace('_', ' ').capitalize()) fig, ax = plt.subplots(ncols=len(images), squeeze=False) for i, (image, title) in enumerate(zip(images, titles)): ax[0, i].imshow(image) ax[0, i].axis('off') if show_titles: ax[0, i].set_title(title) if suptitle is not None: fig.suptitle(suptitle) return fig

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources