Shortcuts

Source code for torchgeo.datasets.skyscript

# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.

"""SkyScript dataset."""

import os
from collections.abc import Callable
from typing import Any, ClassVar

import numpy as np
import pandas as pd
import torch
from einops import rearrange
from matplotlib import pyplot as plt
from matplotlib.figure import Figure
from PIL import Image
from torch import Tensor

from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
from .utils import Path, download_and_extract_archive, download_url, extract_archive


[docs]class SkyScript(NonGeoDataset): """SkyScript dataset. `SkyScript <https://github.com/wangzhecheng/SkyScript>`__ is a large and semantically diverse image-text dataset for remote sensing images. It contains 5.2 million remote sensing image-text pairs in total, covering more than 29K distinct semantic tags. If you use this dataset in your research, please cite it using the following format: * https://arxiv.org/abs/2312.12856 .. versionadded:: 0.6 """ url = 'https://opendatasharing.s3.us-west-2.amazonaws.com/SkyScript/{}' image_dirs = tuple(f'images{i}' for i in range(2, 8)) image_md5s = ( 'fbfb5f7aa1731f4106fc3ffbd608100a', 'ad4fd9fdb9622d1ea360210cb222f2bd', 'aeeb41e830304c74b14b5ffc1fc8e8c3', '02ee7e0e59f9ac1c87b678a155e1f1df', '350475f1e7fb996152fa16db891b4142', '5e2fbf3e9262b36e30b458ec9a1df625', ) #: Can be modified in subclasses to change train/val/test split caption_files: ClassVar[dict[str, str]] = { 'train': 'SkyScript_train_top30pct_filtered_by_CLIP_openai.csv', 'val': 'SkyScript_val_5K_filtered_by_CLIP_openai.csv', 'test': 'SkyScript_test_30K_filtered_by_CLIP_openai.csv', } caption_md5s: ClassVar[dict[str, str]] = { 'train': '05b362e43a852667b5374c9a5ae53f8e', 'val': 'c8d278fd29b754361989d5e7a6608f69', 'test': '0135d9b49ce6751360912a4353e809dc', }
[docs] def __init__( self, root: Path = 'data', split: str = 'train', transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, ) -> None: """Initialize a new SkyScript instance. Args: root: Root directory where dataset can be found. split: One of 'train', 'val', 'test'. transforms: A function/transform that takes input sample and its target as entry and returns a transformed version. download: If True, download dataset and store it in the root directory. checksum: If True, check the MD5 of the downloaded files (may be slow). Raises: AssertionError: If *split* is invalid. DatasetNotFoundError: If dataset is not found and *download* is False. """ assert split in self.caption_files self.root = root self.split = split self.transforms = transforms self.download = download self.checksum = checksum self._verify() self.captions = pd.read_csv(os.path.join(self.root, self.caption_files[split]))
[docs] def __len__(self) -> int: """Return the number of images in the dataset. Returns: Length of the dataset. """ return len(self.captions)
[docs] def __getitem__(self, index: int) -> dict[str, Any]: """Return an index within the dataset. Args: index: Index to return. Returns: A dict containing image and caption at index. """ filepath, title = self.captions.iloc[index][:2] with Image.open(os.path.join(self.root, filepath)) as img: array = np.array(img, dtype=np.float32) array = rearrange(array, 'h w c -> c h w') image = torch.from_numpy(array) sample = {'image': image, 'caption': title} if self.transforms is not None: sample = self.transforms(sample) return sample
def _verify(self) -> None: """Verify the integrity of the dataset.""" md5: str | None for directory, md5 in zip(self.image_dirs, self.image_md5s): # Check if the extracted files already exist if os.path.isdir(os.path.join(self.root, directory)): continue # Check if the zip files have already been downloaded if os.path.isfile(os.path.join(self.root, f'{directory}.zip')): extract_archive(os.path.join(self.root, f'{directory}.zip')) continue # Check if the user requested to download the dataset if not self.download: raise DatasetNotFoundError(self) # Download the dataset url = self.url.format(f'{directory}.zip') md5 = md5 if self.checksum else None download_and_extract_archive(url, self.root, md5=md5) # Download the caption file if self.download: url = self.url.format(self.caption_files[self.split]) md5 = self.caption_md5s[self.split] if self.checksum else None download_url(url, self.root, md5=md5)
[docs] def plot( self, sample: dict[str, Any], show_titles: bool = True, suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. Args: sample: a sample returned by :meth:`RasterDataset.__getitem__` show_titles: flag indicating whether to show titles above each panel suptitle: optional string to use as a suptitle Returns: a matplotlib Figure with the rendered sample """ fig, ax = plt.subplots() image = rearrange(sample['image'], 'c h w -> h w c') / 255 ax.imshow(image) ax.axis('off') if show_titles: title = sample['caption'] if 'prediction' in sample: title += '\n' + sample['prediction'] ax.set_title(title) if suptitle is not None: plt.suptitle(suptitle) return fig

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources