
Source code for torchgeo.datasets.substation

# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.

"""Substation segmentation dataset."""

import glob
import os
from collections.abc import Callable, Sequence
from typing import Literal

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from matplotlib.figure import Figure
from torch import Tensor

from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
from .utils import Path, download_url, extract_archive


class Substation(NonGeoDataset):
    """Substation dataset.

    The `Substation <https://github.com/Lindsay-Lab/substation-seg>`__ dataset is
    curated by TransitionZero and sourced from publicly available data repositories,
    including OpenStreetMap (OSM) and Copernicus Sentinel data. The dataset consists
    of Sentinel-2 images from 27k+ locations; the task is to segment power
    substations, which appear in the majority of locations in the dataset. Most
    locations have 4-5 images taken at different timepoints (i.e., revisits).

    Dataset Format:

    * .npz file for each datapoint

    Dataset Features:

    * 26,522 image-mask pairs stored as numpy files.
    * Data from 5 revisits for most locations.
    * Multi-temporal, multi-spectral images (13 channels) paired with masks,
      with an image size of 228x228 pixels.

    If you use this dataset in your research, please cite the following paper:

    * https://doi.org/10.48550/arXiv.2409.17363
    """

    directory = 'Substation'
    filename_images = 'image_stack.tar.gz'
    filename_masks = 'mask.tar.gz'
    url_for_images = 'https://storage.googleapis.com/tz-ml-public/substation-over-10km2-csv-main-444e360fd2b6444b9018d509d0e4f36e/image_stack.tar.gz'
    url_for_masks = 'https://storage.googleapis.com/tz-ml-public/substation-over-10km2-csv-main-444e360fd2b6444b9018d509d0e4f36e/mask.tar.gz'
    md5_images = '948706609864d0283f74ee7015f9d032'
    md5_masks = 'baa369ececdc2ff80e6ba2b4c7fe147c'
    def __init__(
        self,
        root: Path = 'data',
        bands: Sequence[int] = tuple(range(13)),
        mask_2d: bool = True,
        num_of_timepoints: int = 4,
        timepoint_aggregation: Literal['concat', 'median', 'first', 'random']
        | None = 'concat',
        transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
        download: bool = False,
        checksum: bool = False,
    ) -> None:
        """Initialize a new Substation dataset instance.

        Args:
            root: Path to the directory containing the dataset.
            bands: Channels to use from the image.
            mask_2d: Whether to use a 2D mask.
            num_of_timepoints: Number of timepoints to use for each image.
            timepoint_aggregation: How to aggregate multiple timepoints.
            transforms: A transform that takes an input sample and returns a
                transformed version.
            download: Whether to download the dataset if it is not found.
            checksum: Whether to verify the dataset after downloading.
        """
        self.root = root
        self.bands = bands
        self.mask_2d = mask_2d
        self.num_of_timepoints = num_of_timepoints
        self.timepoint_aggregation = timepoint_aggregation
        self.transforms = transforms
        self.download = download
        self.checksum = checksum
        self.image_dir = os.path.join(root, 'image_stack')
        self.mask_dir = os.path.join(root, 'mask')
        self._verify()
        self.image_filenames = pd.Series(sorted(os.listdir(self.image_dir)))
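For orientation, here is a minimal usage sketch for the constructor above. It assumes the default 13-band, 4-timepoint configuration and that the archives may be fetched into './data'; the root path and download flags are illustrative choices, not requirements of the dataset.

from torchgeo.datasets import Substation

# A minimal sketch: with download=True, construction routes through
# _verify -> _download -> _extract on first use; checksum=True additionally
# verifies the archive MD5s against md5_images / md5_masks.
ds = Substation(
    root='data',
    mask_2d=True,
    num_of_timepoints=4,
    timepoint_aggregation='concat',
    download=True,
    checksum=True,
)
print(len(ds))  # number of image-mask pairs found under root/image_stack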
    def __getitem__(self, index: int) -> dict[str, Tensor]:
        """Get an item from the dataset by index.

        Args:
            index: Index of the item to retrieve.

        Returns:
            A dictionary containing the image and corresponding mask.
        """
        image_filename = self.image_filenames[index]
        image_path = os.path.join(self.image_dir, image_filename)
        mask_path = os.path.join(self.mask_dir, image_filename)

        image = np.load(image_path)['arr_0']

        # Selecting channels
        image = image[:, self.bands, :, :]

        # Handling multiple images across timepoints
        if image.shape[0] < self.num_of_timepoints:
            # Padding: cycle through existing timepoints
            padded_images = []
            for i in range(self.num_of_timepoints):
                padded_images.append(image[i % image.shape[0]])
            image = np.stack(padded_images)
        elif image.shape[0] > self.num_of_timepoints:
            # Removal: take the most recent timepoints
            image = image[-self.num_of_timepoints :]

        match self.timepoint_aggregation:
            case 'concat':
                # (num_of_timepoints * channels, h, w)
                image = np.reshape(image, (-1, image.shape[2], image.shape[3]))
            case 'median':
                image = np.median(image, axis=0)
            case 'first':
                image = image[0]
            case 'random':
                image = image[np.random.randint(image.shape[0])]

        # Class 3 marks substations; binarize the mask to {0, 1}
        mask = np.load(mask_path)['arr_0']
        mask[mask != 3] = 0
        mask[mask == 3] = 1

        image = torch.from_numpy(image)
        mask = torch.from_numpy(mask).long()
        mask = mask.unsqueeze(dim=0)

        if self.mask_2d:
            mask_0 = 1.0 - mask
            mask = torch.concat([mask_0, mask], dim=0)
        mask = mask.squeeze()

        sample = {'image': image, 'mask': mask}

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample
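To make the aggregation logic concrete, the following sketch spells out the sample shapes expected under the defaults used in the earlier sketch (all 13 bands, 4 timepoints, 'concat', mask_2d=True). The 228x228 spatial size comes from the dataset description; this is a sanity check, not part of the dataset API.

sample = ds[0]  # reuses the instance from the earlier sketch
# 'concat' folds timepoints into the channel axis:
# (num_of_timepoints, channels, h, w) -> (4 * 13, 228, 228)
assert sample['image'].shape == (52, 228, 228)
# mask_2d=True stacks a background plane and a substation plane:
assert sample['mask'].shape == (2, 228, 228)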
    def __len__(self) -> int:
        """Return the number of items in the dataset."""
        return len(self.image_filenames)
    def plot(
        self,
        sample: dict[str, Tensor],
        show_titles: bool = True,
        suptitle: str | None = None,
    ) -> Figure:
        """Plot a sample from the dataset.

        Args:
            sample: a sample returned by :meth:`__getitem__`
            show_titles: flag indicating whether to show titles above each panel
            suptitle: optional string to use as a suptitle

        Returns:
            A matplotlib Figure containing the rendered sample.
        """
        ncols = 2

        shape_of_image = sample['image'].shape
        if len(shape_of_image) == 4:
            # Plot the first timepoint
            image = sample['image'][0][:3].permute(1, 2, 0).cpu().numpy()
        else:
            image = sample['image'][:3].permute(1, 2, 0).cpu().numpy()
        image = image / 255.0

        if self.mask_2d:
            mask = sample['mask'][0].squeeze(0).cpu().numpy()
        else:
            mask = sample['mask'].cpu().numpy()

        showing_predictions = 'prediction' in sample
        if showing_predictions:
            prediction = sample['prediction'].cpu().numpy()
            if self.mask_2d:
                prediction = prediction[0]
            ncols = 3

        fig, axs = plt.subplots(ncols=ncols, figsize=(4 * ncols, 4))
        axs[0].imshow(image)
        axs[0].axis('off')
        axs[1].imshow(mask, cmap='gray', interpolation='none')
        axs[1].axis('off')
        if show_titles:
            axs[0].set_title('Image')
            axs[1].set_title('Mask')

        if showing_predictions:
            axs[2].imshow(prediction, cmap='gray', interpolation='none')
            axs[2].axis('off')
            if show_titles:
                axs[2].set_title('Prediction')

        if suptitle:
            fig.suptitle(suptitle)
        return fig
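A short sketch of calling plot on a sample; the suptitle string and output filename are illustrative.

sample = ds[0]
fig = ds.plot(sample, show_titles=True, suptitle='Substation sample')
fig.savefig('substation_sample.png', bbox_inches='tight')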
    def _extract(self) -> None:
        """Extract the dataset."""
        img_pathname = os.path.join(self.root, self.filename_images)
        extract_archive(img_pathname)

        mask_pathname = os.path.join(self.root, self.filename_masks)
        extract_archive(mask_pathname)

    def _verify(self) -> None:
        """Verify the integrity of the dataset."""
        # Check if the extracted files already exist
        image_path = os.path.join(self.image_dir, '*.npz')
        mask_path = os.path.join(self.mask_dir, '*.npz')
        if glob.glob(image_path) and glob.glob(mask_path):
            return

        # Check if the tar.gz files for images and masks have already been downloaded
        image_exists = os.path.exists(os.path.join(self.root, self.filename_images))
        mask_exists = os.path.exists(os.path.join(self.root, self.filename_masks))
        if image_exists and mask_exists:
            self._extract()
            return

        # If dataset files are missing and download is not allowed, raise an error
        if not self.download:
            raise DatasetNotFoundError(self)

        # Download and extract the dataset
        self._download()
        self._extract()

    def _download(self) -> None:
        """Download the dataset and extract it."""
        # Download and verify images
        download_url(
            self.url_for_images,
            self.root,
            filename=self.filename_images,
            md5=self.md5_images if self.checksum else None,
        )
        extract_archive(os.path.join(self.root, self.filename_images), self.root)

        # Download and verify masks
        download_url(
            self.url_for_masks,
            self.root,
            filename=self.filename_masks,
            md5=self.md5_masks if self.checksum else None,
        )
        extract_archive(os.path.join(self.root, self.filename_masks), self.root)
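For completeness, a sketch of the offline path through _verify: if both archives are already present under root (copied there by hand, say), construction extracts them without any network access, so download can stay False. Paths here are illustrative.

# Assumes image_stack.tar.gz and mask.tar.gz already sit in './data';
# _verify finds them and calls _extract, so no download is attempted.
ds_offline = Substation(root='data', download=False)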
