Source code for torchgeo.datasets.solar_plants_brazil

# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.

"""SolarPlantsBrazil dataset."""

import glob
import os
from collections.abc import Callable
from typing import Literal

import matplotlib.pyplot as plt
import numpy as np
import rasterio
import torch
from matplotlib.figure import Figure
from torch import Tensor

from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
from .utils import Path, download_and_extract_archive, extract_archive


class SolarPlantsBrazil(NonGeoDataset):
    """Solar Plants Brazil dataset (semantic segmentation for photovoltaic detection).

    The `Solar Plants Brazil <https://huggingface.co/datasets/FederCO23/solar-plants-brazil>`__
    dataset provides satellite imagery and pixel-level annotations for detecting
    photovoltaic solar power stations.

    Dataset features:

    * 272 RGB+NIR GeoTIFF images (256x256 pixels)
    * Binary masks indicating presence of solar panels (1 = panel, 0 = background)
    * Organized into `train`, `val`, and `test` splits
    * Float32 GeoTIFF files for both input and mask images
    * Spatial metadata included (CRS, bounding box), but not used directly for training

    Folder structure:

    .. code-block:: text

        root/train/input/img(123).tif
        root/train/labels/target(123).tif

    Access:

    * Dataset is hosted on Hugging Face:
      https://huggingface.co/datasets/FederCO23/solar-plants-brazil
    * Code and preprocessing steps available at:
      https://github.com/FederCO23/UCSD_MLBootcamp_Capstone

    .. versionadded:: 0.8
    """

    url = 'https://huggingface.co/datasets/FederCO23/solar-plants-brazil/resolve/1dc13a453ef6acabf08a1781c523fd1db3d9bcc5/'
    filename = 'solarplantsbrazil.zip'
    bands = ('Red', 'Green', 'Blue', 'NIR')
    md5 = 'dfa0d3efdef4143a33b0e7ba834eaafa'
    def __init__(
        self,
        root: Path = 'data',
        split: Literal['train', 'val', 'test'] = 'train',
        transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
        download: bool = False,
        checksum: bool = False,
    ) -> None:
        """Initialize a SolarPlantsBrazil dataset split.

        Args:
            root: root directory where dataset can be found
            split: dataset split to use, one of "train", "val", or "test"
            transforms: a function/transform that takes an input sample
                and returns a transformed version
            download: if True, download dataset and store it in the root directory
            checksum: if True, check the MD5 of the downloaded files (may be slow)

        Raises:
            DatasetNotFoundError: If dataset is not found and *download* is False.
            ValueError: If *split* is invalid.
        """
        if split not in ['train', 'val', 'test']:
            raise ValueError(
                f"Invalid split '{split}', expected one of: 'train', 'val', or 'test'"
            )

        self.root = root
        self.transforms = transforms
        self.dataset_path = os.path.join(self.root, split)
        self.split = split
        self.download = download
        self.checksum = checksum

        self._verify()

        input_dir = os.path.join(self.dataset_path, 'input')
        labels_dir = os.path.join(self.dataset_path, 'labels')
        self.image_paths = sorted(glob.glob(os.path.join(input_dir, 'img(*).tif')))
        self.mask_paths = sorted(glob.glob(os.path.join(labels_dir, 'target(*).tif')))
    def _verify(self) -> None:
        """Verify the integrity of the dataset."""
        directory = os.path.join(self.root, self.split)
        zip_path = os.path.join(self.root, self.filename)
        if os.path.isdir(directory):
            return
        elif os.path.isfile(zip_path):
            extract_archive(zip_path)
        elif self.download:
            self._download()
        else:
            raise DatasetNotFoundError(self)

    def _download(self) -> None:
        """Download and extract the dataset archive."""
        download_and_extract_archive(
            url=self.url + self.filename,
            download_root=self.root,
            filename=self.filename,
            md5=self.md5 if self.checksum else None,
        )
    def __getitem__(self, index: int) -> dict[str, Tensor]:
        """Return the image and mask at the given index.

        Args:
            index: index of the image and mask to return

        Returns:
            image and mask at given index
        """
        image = self._load_image(self.image_paths[index])
        mask = self._load_mask(self.mask_paths[index])
        sample = {'image': image, 'mask': mask}

        if self.transforms:
            sample = self.transforms(sample)

        return sample
    def __len__(self) -> int:
        """Return the number of samples in the dataset.

        Returns:
            The number of image-mask pairs in the dataset.
        """
        return len(self.image_paths)
    def _load_image(self, path: str) -> Tensor:
        """Load an image as a float32 torch tensor.

        Args:
            path: Path to the input image file.

        Returns:
            A float32 tensor with shape (C, H, W).
        """
        with rasterio.open(path) as src:
            arr = src.read().astype(np.float32)
        return torch.from_numpy(arr)

    def _load_mask(self, path: str) -> Tensor:
        """Load a binary mask from file and return as a tensor.

        Args:
            path: Path to the binary mask file.

        Returns:
            A long tensor with shape (H, W), with values 0 or 1.
        """
        with rasterio.open(path) as src:
            arr = src.read(1).astype(np.uint8)
        bin_mask = (arr > 0).astype(np.uint8)
        return torch.from_numpy(bin_mask).long()
    def plot(
        self, sample: dict[str, torch.Tensor], suptitle: str | None = None
    ) -> Figure:
        """Plot a sample from the SolarPlantsBrazil dataset.

        Args:
            sample: A dictionary with 'image' and 'mask' tensors.
            suptitle: Optional string to use as a suptitle.

        Returns:
            A matplotlib Figure with the rendered image and mask.
        """
        image = sample['image']
        mask = sample['mask']

        # Use RGB only
        if image.shape[0] == 4:
            image = image[:3]

        # Normalize for display
        image_np = image.numpy()
        image_np = image_np / np.max(image_np)
        image_np = np.transpose(image_np, (1, 2, 0))

        mask_np = mask.squeeze().numpy()

        fig, axs = plt.subplots(1, 2, figsize=(10, 5))
        axs[0].imshow(image_np)
        axs[0].set_title('RGB Image')
        axs[0].axis('off')

        axs[1].imshow(mask_np, cmap='gray')
        axs[1].set_title('Mask')
        axs[1].axis('off')

        if suptitle is not None:
            plt.suptitle(suptitle)

        plt.tight_layout()
        return fig
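
Example usage (not part of the module source): a minimal sketch showing how this dataset might be instantiated, indexed, and plotted. It assumes the class is re-exported from ``torchgeo.datasets`` and that either the archive already sits under ``root`` or network access is available for ``download=True``; the batch size and root path are illustrative only.

    from torch.utils.data import DataLoader

    from torchgeo.datasets import SolarPlantsBrazil

    # Instantiate the training split; with download=True the archive is fetched
    # from Hugging Face and extracted into `root` on first use (assumption:
    # network access is available).
    dataset = SolarPlantsBrazil(root='data', split='train', download=True, checksum=True)

    # Each sample is a dict with a (4, 256, 256) float32 'image' tensor
    # and a (256, 256) long 'mask' tensor, as produced by __getitem__ above.
    sample = dataset[0]
    print(len(dataset), sample['image'].shape, sample['mask'].shape)

    # Visualize one sample with the built-in plot() helper.
    fig = dataset.plot(sample, suptitle='SolarPlantsBrazil sample')

    # Wrap in a standard DataLoader for training; dict samples of tensors
    # are handled by the default collate function.
    loader = DataLoader(dataset, batch_size=8, shuffle=True)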
