Shortcuts

Source code for torchgeo.datasets.digital_typhoon

# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.

"""Digital Typhoon dataset."""

import glob
import os
import tarfile
from collections.abc import Callable, Sequence
from typing import Any, ClassVar, TypedDict

import einops
import matplotlib.pyplot as plt
import pandas as pd
import torch
from matplotlib.figure import Figure
from torch import Tensor

from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
from .utils import Path, download_url, lazy_import, percentile_normalization


class _SampleSequenceDict(TypedDict):
    """Sample sequence dictionary."""

    id: str
    seq_id: list[int]


[docs]class DigitalTyphoon(NonGeoDataset): """Digital Typhoon Dataset for Analysis Task. This dataset contains typhoon-centered images, derived from hourly infrared channel images captured by meteorological satellites. It incorporates data from multiple generations of the Himawari weather satellite, dating back to 1978. These images have been transformed into brightness temperatures and adjusted for varying satellite sensor readings, yielding a consistent spatio-temporal dataset that covers over four decades. See `the Digital Typhoon website <http://agora.ex.nii.ac.jp/digital-typhoon/dataset/>`_ for more information about the dataset. Dataset features: * infrared channel images from the Himawari weather satellite (512x512 px) at 5km spatial resolution * auxiliary features such as wind speed, pressure, and more that can be used for regression or classification tasks * 1,099 typhoons and 189,364 images Dataset format: * hdf5 files containing the infrared channel images * .csv files containing the metadata for each image If you use this dataset in your research, please cite the following papers: * https://doi.org/10.20783/DIAS.664 .. versionadded:: 0.6 """ valid_tasks = ('classification', 'regression') aux_file_name = 'aux_data.csv' valid_features = ( 'year', 'month', 'day', 'hour', 'grade', 'lat', 'lng', 'pressure', 'wind', 'dir50', 'long50', 'short50', 'dir30', 'long30', 'short30', 'landfall', 'intp', ) url = 'https://hf.co/datasets/torchgeo/digital_typhoon/resolve/cf2f9ef89168d31cb09e42993d35b068688fe0df/WP.tar.gz{0}' md5sums: ClassVar[dict[str, str]] = { 'aa': '9e77a5f74783f7909dee0fb936053b17', 'ab': '46aebdcba6e4e2df1619e4a3d7e556bb', } min_input_clamp = 170.0 max_input_clamp = 300.0 data_root = 'WP'
[docs] def __init__( self, root: Path = 'data', task: str = 'regression', features: Sequence[str] = ['wind'], targets: Sequence[str] = ['wind'], sequence_length: int = 3, min_feature_value: dict[str, float] | None = None, max_feature_value: dict[str, float] | None = None, transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, download: bool = False, checksum: bool = False, ) -> None: """Initialize a new Digital Typhoon dataset instance. Args: root: root directory where dataset can be found task: whether to load 'regression' or 'classification' labels features: which auxiliary features to return targets: which auxiliary features to use as targets sequence_length: length of the sequence to return min_feature_value: minimum value for each feature max_feature_value: maximum value for each feature transforms: a function/transform that takes input sample and its target as entry and returns a transformed version download: if True, download dataset and store it in the root directory checksum: if True, check the MD5 of the downloaded files (may be slow) Raises: AssertionError: If any arguments are invalid. DatasetNotFoundError: If dataset is not found and *download* is False. DependencyNotFoundError: If h5py is not installed. """ lazy_import('h5py') self.root = root self.transforms = transforms self.download = download self.checksum = checksum self.sequence_length = sequence_length self.min_feature_value = min_feature_value self.max_feature_value = max_feature_value assert task in self.valid_tasks, ( f'Please choose one of {self.valid_tasks}, you provided {task}.' ) self.task = task assert set(features).issubset(set(self.valid_features)) self.features = features assert set(targets).issubset(set(self.valid_features)) self.targets = targets self._verify() self.aux_df = pd.read_csv( os.path.join(root, self.data_root, self.aux_file_name) ) self.aux_df['datetime'] = pd.to_datetime( self.aux_df[['year', 'month', 'day', 'hour']] ) self.aux_df = self.aux_df.sort_values(['year', 'month', 'day', 'hour']) self.aux_df['seq_id'] = self.aux_df.groupby(['id']).cumcount() self.aux_df.columns = [str(col) for col in self.aux_df.columns] # Compute the hour difference between consecutive images per typhoon id self.aux_df['hour_diff_consecutive'] = ( self.aux_df.sort_values(['id', 'datetime']) .groupby('id')['datetime'] .diff() .dt.total_seconds() / 3600 ) # Compute the hour difference between the first and second entry self.aux_df['hour_diff_to_next'] = ( self.aux_df.groupby('id')['datetime'] .shift(-1) .sub(self.aux_df['datetime']) .abs() .dt.total_seconds() / 3600 ) self.aux_df['hour_diff'] = self.aux_df['hour_diff_consecutive'].combine_first( self.aux_df['hour_diff_to_next'] ) self.aux_df.drop( ['hour_diff_consecutive', 'hour_diff_to_next'], axis=1, inplace=True ) # 0 hour difference is for the last time step of each typhoon sequence and want # to keep only images that have max 1 hour difference self.aux_df = self.aux_df[self.aux_df['hour_diff'] <= 1] # Filter out all ids that only have less than sequence_length entries self.aux_df = self.aux_df.groupby('id').filter( lambda x: len(x) >= self.sequence_length ) # Filter aux_df according to min_target_value if self.min_feature_value is not None: for feature, min_value in self.min_feature_value.items(): self.aux_df = self.aux_df[self.aux_df[feature] >= min_value] # Filter aux_df according to max_target_value if self.max_feature_value is not None: for feature, max_value in self.max_feature_value.items(): self.aux_df = self.aux_df[self.aux_df[feature] <= max_value] # collect target mean and std for each target self.target_mean: dict[str, float] = self.aux_df[self.targets].mean().to_dict() self.target_std: dict[str, float] = self.aux_df[self.targets].std().to_dict() def _get_subsequences(df: pd.DataFrame, k: int) -> list[dict[str, list[int]]]: """Generate all possible subsequences of length k for a given group. Args: df: grouped dataframe of a single typhoon k: length of the subsequences to generate Returns: list of all possible subsequences of length k for a given typhoon id """ min_seq_id = df['seq_id'].min() max_seq_id = df['seq_id'].max() # generate possible subsquences of length k for group subsequences = [ {'id': df['id'].iloc[0], 'seq_id': list(range(i, i + k))} for i in range(min_seq_id, max_seq_id - k + 2) ] return [ subseq for subseq in subsequences if set(subseq['seq_id']).issubset(df['seq_id']) ] self.sample_sequences: list[_SampleSequenceDict] = [ item for sublist in self.aux_df.groupby('id')[['seq_id', 'id']] .apply(_get_subsequences, k=self.sequence_length) .tolist() for item in sublist ]
[docs] def __getitem__(self, index: int) -> dict[str, Any]: """Return an index within the dataset. Args: index: index to return Returns: data, labels, and metadata at that index """ sample_entry = self.sample_sequences[index] sample_df = self.aux_df[ (self.aux_df['id'] == sample_entry['id']) & (self.aux_df['seq_id'].isin(sample_entry['seq_id'])) ] sample = {'image': self._load_image(sample_df)} # load features of the last image in the sequence sample.update( self._load_features( os.path.join( self.root, self.data_root, 'metadata', str(sample_df.iloc[-1]['id']) + '.csv', ), sample_df.iloc[-1]['image_path'], ) ) # torchgeo expects a single label sample['label'] = torch.Tensor([sample[target] for target in self.targets]) if self.transforms is not None: sample = self.transforms(sample) return sample
[docs] def __len__(self) -> int: """Return the number of data points in the dataset. Returns: length of the dataset """ return len(self.sample_sequences)
def _load_image(self, sample_df: pd.DataFrame) -> Tensor: """Load a single image. Args: sample_df: df holding all information necessary to load the consecutive images in the sequence Returns: concatenation of all images in the sequence over channel dimension """ def load_image_tensor(id: str, filepath: str) -> Tensor: """Load a single image tensor from a h5 file. Args: id: typhoon id filepath: path to the h5 file Returns: image tensor """ h5py = lazy_import('h5py') full_path = os.path.join(self.root, self.data_root, 'image', id, filepath) with h5py.File(full_path, 'r') as h5f: # tensor with added channel dimension tensor = torch.from_numpy(h5f['Infrared'][:]).unsqueeze(0) # follow normalization procedure # https://github.com/kitamoto-lab/benchmarks/blob/1bdbefd7c570cb1bdbdf9e09f9b63f7c22bbdb27/analysis/regression/FrameDatamodule.py#L94 tensor = torch.clamp(tensor, self.min_input_clamp, self.max_input_clamp) tensor = (tensor - self.min_input_clamp) / ( self.max_input_clamp - self.min_input_clamp ) return tensor # tensor of shape [sequence_length, height, width] tensor = torch.cat( [ load_image_tensor(str(id), filepath) for id, filepath in zip(sample_df['id'], sample_df['image_path']) ] ).float() return tensor def _load_features(self, filepath: str, image_path: str) -> dict[str, Any]: """Load features for the corresponding image. Args: filepath: path of the feature file to load image_path: image path for the unique image for which to retrieve features Returns: features for image """ feature_df = pd.read_csv(filepath) feature_df = feature_df[feature_df['file_1'] == image_path] feature_dict = { name: torch.tensor(feature_df[name].item()).float() for name in self.features } # normalize the targets for regression if self.task == 'regression': for feature, mean in self.target_mean.items(): feature_dict[feature] = ( feature_dict[feature] - mean ) / self.target_std[feature] return feature_dict def _verify(self) -> None: """Verify the integrity of the dataset.""" # Check if the extracted files already exist exists = [] path = os.path.join(self.root, self.data_root, 'image', '*', '*.h5') if glob.glob(path): exists.append(True) else: exists.append(False) # check if aux.csv file exists exists.append( os.path.exists(os.path.join(self.root, self.data_root, self.aux_file_name)) ) if all(exists): return # Check if the tar.gz files have already been downloaded exists = [] for suffix in self.md5sums.keys(): path = os.path.join(self.root, f'{self.data_root}.tar.gz{suffix}') exists.append(os.path.exists(path)) if all(exists): self._extract() return # Check if the user requested to download the dataset if not self.download: raise DatasetNotFoundError(self) # Download amd extract the dataset self._download() self._extract() def _download(self) -> None: """Download the dataset.""" for suffix, md5 in self.md5sums.items(): download_url( self.url.format(suffix), self.root, md5=md5 if self.checksum else None ) def _extract(self) -> None: """Extract the dataset.""" # Extract tarball for suffix in self.md5sums.keys(): with tarfile.open( os.path.join(self.root, f'{self.data_root}.tar.gz{suffix}') ) as tar: tar.extractall(path=self.root)
[docs] def plot( self, sample: dict[str, Any], show_titles: bool = True, suptitle: str | None = None, ) -> Figure: """Plot a sample from the dataset. Args: sample: a sample return by :meth:`__getitem__` show_titles: flag indicating whether to show titles above each panel suptitle: optional suptitle to use for figure Returns: a matplotlib Figure with the rendered sample """ image, label = sample['image'].numpy(), sample['label'].numpy() image = percentile_normalization(image) image = einops.rearrange(image, 'c h w -> h w c') showing_predictions = 'prediction' in sample if showing_predictions: prediction = sample['prediction'] fig, ax = plt.subplots(1, 1, figsize=(10, 10)) ax.imshow(image) ax.axis('off') if show_titles: title_dict = { label_name: label[idx].item() for idx, label_name in enumerate(self.targets) } title = f'Label: {title_dict}' if showing_predictions: title_dict = { label_name: prediction[idx].item() for idx, label_name in enumerate(self.targets) } title += f'\nPrediction: {title_dict}' ax.set_title(title) if suptitle is not None: plt.suptitle(suptitle) return fig

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources