Source code for torchgeo.datasets.utils
# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.
"""Common dataset utilities."""
# https://github.com/sphinx-doc/sphinx/issues/11327
from __future__ import annotations
import collections
import contextlib
import importlib
import os
import shutil
import subprocess
from collections.abc import Iterable, Iterator, Mapping, MutableMapping, Sequence
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Any, TypeAlias, cast, overload
import numpy as np
import pandas as pd
import rasterio
import shapely
import torch
from rasterio import Affine
from torch import Tensor
from torchvision.datasets.utils import (
check_integrity,
download_and_extract_archive,
download_url,
extract_archive,
)
from torchvision.utils import draw_segmentation_masks
from typing_extensions import deprecated
from .errors import DependencyNotFoundError
# Only include import redirects
# (the four names below are the helpers imported from
# torchvision.datasets.utils at the top of this module, re-exported here)
__all__ = (
    'check_integrity',
    'download_and_extract_archive',
    'download_url',
    'extract_archive',
)

#: Index type for geospatial datasets: a single slice, or a tuple of one to
#: three slices. NOTE(review): presumably ordered (x, y, t) — confirm with the
#: dataset ``__getitem__`` implementations.
GeoSlice: TypeAlias = (
    slice | tuple[slice] | tuple[slice, slice] | tuple[slice, slice, slice]
)

#: A filesystem path given either as ``str`` or as an ``os.PathLike[str]``.
Path: TypeAlias = str | os.PathLike[str]
@deprecated('Use torchgeo.datasets.utils.GeoSlice or shapely.Polygon instead')
@dataclass(frozen=True)
class BoundingBox:
    """Immutable spatiotemporal box used to index geospatial data.

    The box covers ``[minx, maxx]`` x ``[miny, maxy]`` x ``[mint, maxt]``;
    all comparisons below treat the bounds as inclusive.
    """

    #: western boundary
    minx: float
    #: eastern boundary
    maxx: float
    #: southern boundary
    miny: float
    #: northern boundary
    maxy: float
    #: earliest boundary
    mint: datetime
    #: latest boundary
    maxt: datetime

    def __post_init__(self) -> None:
        """Reject boxes whose lower bound exceeds the upper bound on any axis.

        Raises:
            ValueError: if bounding box is invalid
                (minx > maxx, miny > maxy, or mint > maxt)

        .. versionadded:: 0.2
        """
        # Checked in x, y, t order; the first violated axis is reported
        if self.minx > self.maxx:
            raise ValueError(
                f"Bounding box is invalid: 'minx={self.minx}' > 'maxx={self.maxx}'"
            )
        if self.miny > self.maxy:
            raise ValueError(
                f"Bounding box is invalid: 'miny={self.miny}' > 'maxy={self.maxy}'"
            )
        if self.mint > self.maxt:
            raise ValueError(
                f"Bounding box is invalid: 'mint={self.mint}' > 'maxt={self.maxt}'"
            )

    @overload
    def __getitem__(self, key: int) -> Any: ...

    @overload
    def __getitem__(self, key: slice) -> list[Any]: ...

    def __getitem__(self, key: int | slice) -> Any | list[Any]:
        """Look up entries of the (minx, maxx, miny, maxy, mint, maxt) sequence.

        Args:
            key: integer or slice object

        Returns:
            the value at an integer index, or a list of values for a slice

        Raises:
            IndexError: if key is out of bounds
        """
        bounds = [self.minx, self.maxx, self.miny, self.maxy, self.mint, self.maxt]
        return bounds[key]

    def __iter__(self) -> Iterator[Any]:
        """Iterate over minx, maxx, miny, maxy, mint, maxt in that order.

        Returns:
            iterator object that iterates over all objects in the container
        """
        yield from self[:]

    def __contains__(self, other: BoundingBox) -> bool:
        """Check whether ``other`` lies entirely inside this box.

        Args:
            other: another bounding box

        Returns:
            True if every bound of ``other`` falls within the matching range of
            this box, else False

        .. versionadded:: 0.2
        """
        return (
            self.minx <= other.minx <= self.maxx
            and self.minx <= other.maxx <= self.maxx
            and self.miny <= other.miny <= self.maxy
            and self.miny <= other.maxy <= self.maxy
            and self.mint <= other.mint <= self.maxt
            and self.mint <= other.maxt <= self.maxt
        )

    def __or__(self, other: BoundingBox) -> BoundingBox:
        """Union: the smallest box enclosing both ``self`` and ``other``.

        Args:
            other: another bounding box

        Returns:
            the minimum bounding box that contains both self and other

        .. versionadded:: 0.2
        """
        return BoundingBox(
            minx=min(self.minx, other.minx),
            maxx=max(self.maxx, other.maxx),
            miny=min(self.miny, other.miny),
            maxy=max(self.maxy, other.maxy),
            mint=min(self.mint, other.mint),
            maxt=max(self.maxt, other.maxt),
        )

    def __and__(self, other: BoundingBox) -> BoundingBox:
        """Intersection: the region shared by ``self`` and ``other``.

        Args:
            other: another bounding box

        Returns:
            the intersection of self and other

        Raises:
            ValueError: if self and other do not intersect

        .. versionadded:: 0.2
        """
        try:
            return BoundingBox(
                minx=max(self.minx, other.minx),
                maxx=min(self.maxx, other.maxx),
                miny=max(self.miny, other.miny),
                maxy=min(self.maxy, other.maxy),
                mint=max(self.mint, other.mint),
                maxt=min(self.maxt, other.maxt),
            )
        except ValueError:
            # __post_init__ rejects the inverted box produced by disjoint inputs
            raise ValueError(f'Bounding boxes {self} and {other} do not overlap')

    @property
    def area(self) -> float:
        """Spatial (x by y) area of the box.

        Returns:
            area

        .. versionadded:: 0.3
        """
        width = self.maxx - self.minx
        height = self.maxy - self.miny
        return width * height

    @property
    def volume(self) -> timedelta:
        """Spatial area multiplied by the temporal extent.

        Returns:
            volume

        .. versionadded:: 0.3
        """
        duration = self.maxt - self.mint
        return self.area * duration

    def intersects(self, other: BoundingBox) -> bool:
        """Check whether this box and ``other`` overlap (touching counts).

        Args:
            other: another bounding box

        Returns:
            True if bounding boxes intersect, else False
        """
        return (
            self.minx <= other.maxx
            and self.maxx >= other.minx
            and self.miny <= other.maxy
            and self.maxy >= other.miny
            and self.mint <= other.maxt
            and self.maxt >= other.mint
        )

    def split(
        self, proportion: float, horizontal: bool = True
    ) -> tuple[BoundingBox, BoundingBox]:
        """Cut the box into two pieces at a relative position.

        Args:
            proportion: split proportion in range (0,1)
            horizontal: if True, cut at an x coordinate; otherwise cut at a
                y coordinate

        Returns:
            A tuple with the resulting BoundingBoxes (lower piece first)

        .. versionadded:: 0.5
        """
        if not 0.0 < proportion < 1.0:
            raise ValueError('Input proportion must be between 0 and 1.')

        if horizontal:
            cut = self.minx + (self.maxx - self.minx) * proportion
            return (
                BoundingBox(self.minx, cut, self.miny, self.maxy, self.mint, self.maxt),
                BoundingBox(cut, self.maxx, self.miny, self.maxy, self.mint, self.maxt),
            )

        cut = self.miny + (self.maxy - self.miny) * proportion
        return (
            BoundingBox(self.minx, self.maxx, self.miny, cut, self.mint, self.maxt),
            BoundingBox(self.minx, self.maxx, cut, self.maxy, self.mint, self.maxt),
        )
class Executable:
    """Callable handle for a command-line program.

    .. versionadded:: 0.6
    """

    def __init__(self, name: Path) -> None:
        """Remember which program to invoke.

        Args:
            name: Command name.
        """
        self.name = name

    def __call__(self, *args: Any, **kwargs: Any) -> subprocess.CompletedProcess[bytes]:
        """Run the program with ``args`` appended to the command line.

        Args:
            args: Arguments to pass to the command.
            kwargs: Keyword arguments to pass to :func:`subprocess.run`.
                ``check`` is always forced to True, so a non-zero exit status
                raises :class:`subprocess.CalledProcessError`.

        Returns:
            The completed process.
        """
        argv = (self.name, *args)
        options = {**kwargs, 'check': True}
        return subprocess.run(argv, **options)
def disambiguate_timestamp(
    date_str: str | None, format: str
) -> tuple[datetime, datetime]:
    """Disambiguate partial timestamps.

    TorchGeo stores the timestamp of each file in a pandas IntervalIndex. If the full
    timestamp isn't known, a file could represent a range of time. For example, in the
    CDL dataset, each mask spans an entire year. This method returns the maximum
    possible range of timestamps that ``date_str`` could belong to. It does this by
    parsing ``format`` to determine the level of precision of ``date_str``.

    Args:
        date_str: string representing date and time of a data point
        format: format codes accepted by :meth:`datetime.datetime.strptime`

    Returns:
        (mint, maxt) tuple for indexing. ``maxt`` is the last microsecond of the
        period that ``date_str`` denotes and carries the same ``tzinfo`` as
        ``mint``. Returns ``(pd.NaT, pd.NaT)`` if ``date_str`` is not a string,
        and ``(pd.Timestamp.min, pd.Timestamp.max)`` if ``format`` contains no
        year information.
    """
    if not isinstance(date_str, str):
        return pd.NaT, pd.NaT

    mint = datetime.strptime(date_str, format)

    # Drop escaped literal percent signs so that e.g. '%%d' is not mistaken for
    # the day directive '%d'.
    format = format.replace('%%', '')

    def has_any(codes: str) -> bool:
        """Whether ``format`` contains any of the directives ``%<c>`` for c in codes."""
        return any(f'%{c}' in format for c in codes)

    # TODO: May have issues with UTC vs. local time and DST
    # Year/month bounds are built from scratch, so propagate mint's tzinfo to
    # avoid pairing an aware mint with a naive maxt.
    tz = mint.tzinfo
    if not has_any('yYcxG'):
        # No temporal info
        return pd.Timestamp.min, pd.Timestamp.max
    elif not has_any('bBmjUWcxV'):
        # Year resolution
        maxt = datetime(mint.year + 1, 1, 1, tzinfo=tz)
    elif not has_any('aAwdjcxV'):
        # Month resolution
        if mint.month == 12:
            maxt = datetime(mint.year + 1, 1, 1, tzinfo=tz)
        else:
            maxt = datetime(mint.year, mint.month + 1, 1, tzinfo=tz)
    elif not has_any('HIcX'):
        # Day resolution
        maxt = mint + timedelta(days=1)
    elif not has_any('McX'):
        # Hour resolution
        maxt = mint + timedelta(hours=1)
    elif not has_any('ScX'):
        # Minute resolution
        maxt = mint + timedelta(minutes=1)
    elif not has_any('f'):
        # Second resolution
        maxt = mint + timedelta(seconds=1)
    else:
        # Microsecond resolution
        maxt = mint + timedelta(microseconds=1)

    # Make the upper bound inclusive: the last instant before the next period
    maxt -= timedelta(microseconds=1)

    return mint, maxt
@contextlib.contextmanager
def working_dir(dirname: Path, create: bool = False) -> Iterator[None]:
    """Temporarily make ``dirname`` the process working directory.

    The previous working directory is restored when the ``with`` block exits,
    whether normally or through an exception.

    Args:
        dirname: directory to temporarily change to
        create: if True, create the destination directory (including missing
            parents; an existing directory is not an error)
    """
    if create:
        os.makedirs(dirname, exist_ok=True)
    previous = os.getcwd()
    os.chdir(dirname)
    with contextlib.ExitStack() as restore:
        # Registered only after the chdir succeeded, mirroring try/finally
        restore.callback(os.chdir, previous)
        yield
def _list_dict_to_dict_list(
samples: Iterable[Mapping[Any, Any]],
) -> dict[Any, list[Any]]:
"""Convert a list of dictionaries to a dictionary of lists.
Args:
samples: a list of dictionaries
Returns:
a dictionary of lists
.. versionadded:: 0.2
"""
collated: dict[Any, list[Any]] = dict()
for sample in samples:
for key, value in sample.items():
if key not in collated:
collated[key] = []
collated[key].append(value)
return collated
def _dict_list_to_list_dict(
sample: Mapping[Any, Sequence[Any]],
) -> list[dict[Any, Any]]:
"""Convert a dictionary of lists to a list of dictionaries.
Args:
sample: a dictionary of lists
Returns:
a list of dictionaries
.. versionadded:: 0.2
"""
uncollated: list[dict[Any, Any]] = [
{} for _ in range(max(map(len, sample.values())))
]
for key, values in sample.items():
for i, value in enumerate(values):
uncollated[i][key] = value
return uncollated
def stack_samples(samples: Iterable[Mapping[Any, Any]]) -> dict[Any, Any]:
    """Stack a list of samples along a new axis.

    Useful for forming a mini-batch of samples to pass to
    :class:`torch.utils.data.DataLoader`.

    Values are grouped per key; a key whose first value is a Tensor has its
    values combined with :func:`torch.stack`, while any other key keeps the
    plain list of values.

    Args:
        samples: list of samples

    Returns:
        a single sample

    .. versionadded:: 0.2
    """
    grouped = _list_dict_to_dict_list(samples)
    return {
        key: torch.stack(values) if isinstance(values[0], Tensor) else values
        for key, values in grouped.items()
    }
def concat_samples(samples: Iterable[Mapping[Any, Any]]) -> dict[Any, Any]:
    """Concatenate a list of samples along an existing axis.

    Useful for joining samples in a :class:`torchgeo.datasets.IntersectionDataset`.

    For each key, Tensor values (judged by the first one) are joined with
    :func:`torch.cat`; for non-Tensor keys only the first sample's value is kept.

    Args:
        samples: list of samples

    Returns:
        a single sample

    .. versionadded:: 0.2
    """
    grouped = _list_dict_to_dict_list(samples)
    return {
        key: torch.cat(values) if isinstance(values[0], Tensor) else values[0]
        for key, values in grouped.items()
    }
def merge_samples(samples: Iterable[Mapping[Any, Any]]) -> dict[Any, Any]:
    """Merge a list of samples.

    Useful for joining samples in a :class:`torchgeo.datasets.UnionDataset`.

    When a Tensor value arrives for a key already present, the two are combined
    element-wise with :func:`torch.maximum`; any other value simply overwrites
    the earlier one, so the last sample wins.

    Args:
        samples: list of samples

    Returns:
        a single sample

    .. versionadded:: 0.2
    """
    merged: dict[Any, Any] = {}
    for sample in samples:
        for key, value in sample.items():
            if isinstance(value, Tensor) and key in merged:
                # Take the maximum so that nodata values (zeros) get replaced
                # by data values whenever possible
                value = torch.maximum(merged[key], value)
            merged[key] = value
    return merged
def unbind_samples(sample: MutableMapping[Any, Any]) -> list[dict[Any, Any]]:
    """Reverse of :func:`stack_samples`.

    Useful for turning a mini-batch of samples into a list of samples. These individual
    samples can then be plotted using a dataset's ``plot`` method.

    Tensor values are split along their first dimension with :func:`torch.unbind`;
    all other values are passed through and must already be indexable per sample.
    ``sample`` itself is not modified.

    Args:
        sample: a mini-batch of samples

    Returns:
        list of samples

    .. versionadded:: 0.2
    """
    # Build a new mapping rather than overwriting the caller's batch in place
    # (previously its tensors were silently replaced by tuples of tensors).
    unbound = {
        key: torch.unbind(values) if isinstance(values, Tensor) else values
        for key, values in sample.items()
    }
    return _dict_list_to_list_dict(unbound)
def rasterio_loader(path: Path) -> np.typing.NDArray[np.int_]:
    """Read a raster file with rasterio into an integer image array.

    Args:
        path: path to the image to be loaded

    Returns:
        the image, cast to int32 and in channels-last (HWC) order
    """
    with rasterio.open(path) as dataset:
        bands = dataset.read()
    image: np.typing.NDArray[np.int_] = bands.astype(np.int32)
    # NonGeoClassificationDataset expects images returned with channels last (HWC)
    return image.transpose(1, 2, 0)
def sort_sentinel2_bands(x: Path) -> str:
    """Sort key that puts Sentinel-2 band files in the correct order.

    The key is the last ``_``-separated token of the file name with its
    extension removed. ``B8A`` is rewritten to ``B08A`` so that it sorts
    directly after ``B08`` and before ``B09``.
    """
    band, _ = os.path.splitext(os.path.basename(x).split('_')[-1])
    return 'B08A' if band == 'B8A' else band
def draw_semantic_segmentation_masks(
    image: Tensor,
    mask: Tensor,
    alpha: float = 0.5,
    colors: Sequence[str | tuple[int, int, int]] | None = None,
) -> np.typing.NDArray[np.uint8]:
    """Overlay a semantic segmentation mask onto an image.

    Args:
        image: tensor of shape (3, h, w); it is converted with ``.byte()``
            (uint8) before drawing
        mask: integer tensor of shape (h, w) whose pixel values are class indices
        alpha: alpha blend factor
        colors: one color per class, indexed by class value; each entry is an
            RGB int tuple or a color string e.g. red, #FF00FF. Only classes
            ``0 .. len(colors) - 1`` get a mask. If ``colors`` is None or empty,
            no class masks are built, so nothing is overlaid.

    Returns:
        a version of ``image`` overlaid with the colors given by ``mask`` and
        ``colors``, as a channels-last uint8 numpy array
    """
    num_classes = len(colors) if colors else 0
    # Class ids are uint8; NOTE(review): more than 256 colors cannot be
    # represented here — confirm callers never pass that many.
    classes = torch.from_numpy(np.arange(num_classes, dtype=np.uint8))
    # One boolean mask per class via broadcasting: (num_classes, h, w)
    class_masks = mask == classes[:, None, None]
    img = draw_segmentation_masks(
        image=image.byte(), masks=class_masks, alpha=alpha, colors=colors
    )
    # CHW tensor -> HWC numpy array
    img = img.permute((1, 2, 0)).numpy().astype(np.uint8)
    return cast('np.typing.NDArray[np.uint8]', img)
def rgb_to_mask(
    rgb: np.typing.NDArray[np.uint8], colors: Sequence[tuple[int, int, int]]
) -> np.typing.NDArray[np.uint8]:
    """Translate an RGB color-coded mask into a mask of integer class indices.

    A pixel whose color equals ``colors[i]`` becomes ``i``; pixels matching no
    color stay 0. If a color is listed twice, the later index wins.

    Args:
        rgb: mask coded with RGB tuples, color channels on the last axis
        colors: list of RGB tuples to convert to integer indices

    Returns:
        integer array mask with the same height and width as ``rgb``
    """
    # The result is uint8, so the largest index it can hold is 255
    assert len(colors) <= 256
    height, width = rgb.shape[:2]
    mask: np.typing.NDArray[np.uint8] = np.zeros(shape=(height, width), dtype=np.uint8)
    for index, color in enumerate(colors):
        matches = rgb == color
        # Presumably guards against NumPy returning a scalar instead of an
        # array when the comparison cannot broadcast; such colors are skipped.
        if not isinstance(matches, np.ndarray):
            continue
        mask[matches.all(axis=-1)] = index
    return mask
def percentile_normalization(
    img: np.typing.NDArray[np.int_],
    lower: float = 2,
    upper: float = 98,
    axis: int | Sequence[int] | None = None,
) -> np.typing.NDArray[np.int_]:
    """Applies percentile normalization to an input image.

    Specifically, this will rescale the values in the input such that values <= the
    lower percentile value will be 0 and values >= the upper percentile value will be 1.
    Using the 2nd and 98th percentile usually results in good visualizations.

    Args:
        img: image to normalize
        lower: lower percentile in range [0,100]
        upper: upper percentile in range [0,100]
        axis: Axis or axes along which the percentiles are computed, e.g.
            ``(0, 1)`` for per-channel normalization of an HWC image or
            ``(1, 2)`` for a CHW image. The default is to compute the
            percentile(s) along a flattened version of the array.

    Returns:
        normalized version of ``img`` with the same shape, values in [0, 1]
        (floating point, despite the integer annotation)

    .. versionadded:: 0.2
    """
    assert lower < upper
    # keepdims=True keeps the reduced axes as size-1 dimensions, so the
    # percentiles broadcast back along exactly the axes they were computed
    # over. Without it, e.g. axis=1 on a 2-D array or axis=(1, 2) on a CHW
    # image would broadcast against the wrong (last) axis.
    lower_percentile = np.percentile(img, lower, axis=axis, keepdims=True)
    upper_percentile = np.percentile(img, upper, axis=axis, keepdims=True)
    # The 1e-5 term avoids division by zero where lower == upper percentile
    img_normalized: np.typing.NDArray[np.int_] = np.clip(
        (img - lower_percentile) / (upper_percentile - lower_percentile + 1e-5), 0, 1
    )
    return img_normalized
def path_is_vsi(path: Path) -> bool:
    """Tell whether ``path`` refers to a GDAL/rasterio virtual file system.

    .. note::
       Only the textual form of the path is inspected; it is not checked for
       existence, nor whether it names a file or a directory.

       Virtual paths (e.g. cloud storage blobs or members of zip archives) are
       recognized by either of the two accepted syntaxes: a URL-like
       ``scheme://`` form or a GDAL ``/vsi...`` prefix.

       * https://gdal.org/user/virtual_file_systems.html
       * https://rasterio.readthedocs.io/en/latest/topics/datasets.html

    Args:
        path: a directory or file

    Returns:
        True if path is on a virtual file system, else False

    .. versionadded:: 0.6
    """
    text = str(path)
    return text.startswith('/vsi') or '://' in text
def array_to_tensor(array: np.typing.NDArray[Any]) -> Tensor:
    """Converts a :class:`numpy.ndarray` to :class:`torch.Tensor`.

    PyTorch has historically offered no (or only limited) support for unsigned
    numpy dtypes wider than uint8, so :func:`torch.from_numpy` may reject
    arrays such as uint16. This function instead casts uint16 and uint32 numpy
    arrays to the next wider signed pytorch type, which holds every value of
    the source type without loss of precision. For example, a uint32 array
    becomes an int64 tensor. uint64 arrays are passed through unchanged, since
    there is no wider signed torch dtype; depending on the PyTorch version they
    may raise an error.

    The returned tensor is a copy.

    Args:
        array: a :class:`numpy.ndarray`.

    Returns:
        A :class:`torch.Tensor` with the same dtype as array unless array is uint16 or
        uint32, in which case an int32 or int64 Tensor is returned, respectively.

    .. versionadded:: 0.6
    """
    if array.dtype == np.uint16:
        array = array.astype(np.int32)
    elif array.dtype == np.uint32:
        array = array.astype(np.int64)
    # torch.tensor always copies, unlike torch.from_numpy which shares memory
    return torch.tensor(array)
def lazy_import(name: str) -> Any:
    """Import module *name* on demand, with a helpful error if it is missing.

    Args:
        name: Name of module to import.

    Returns:
        Module import.

    Raises:
        DependencyNotFoundError: If *name* is not installed. The message names
            the PyPI package to install, derived from the top-level module.

    .. versionadded:: 0.6
    """
    try:
        return importlib.import_module(name)
    except ModuleNotFoundError:
        # PyPI names usually match the top-level module with '_' -> '-';
        # a few packages are published under an unrelated name.
        pypi_overrides = {'cv2': 'opencv-python', 'skimage': 'scikit-image'}
        top_level = name.split('.')[0].replace('_', '-')
        name = pypi_overrides.get(top_level, top_level)
        msg = f"""\
{name} is not installed and is required to use this feature. Either run:
$ pip install {name}
to install just this dependency, or:
$ pip install torchgeo[datasets,models]
to install all optional dependencies."""
        raise DependencyNotFoundError(msg) from None
def which(name: Path) -> Executable:
    """Locate executable *name* with :func:`shutil.which`.

    Args:
        name: Name of executable to search for.

    Returns:
        Callable executable instance wrapping the resolved path.

    Raises:
        DependencyNotFoundError: If *name* is not installed.

    .. versionadded:: 0.6
    """
    resolved = shutil.which(name)
    if not resolved:
        msg = f'{name} is not installed and is required to use this dataset.'
        raise DependencyNotFoundError(msg) from None
    return Executable(resolved)
def convert_poly_coords(
    geom: shapely.geometry.shape, affine_obj: Affine, inverse: bool = False
) -> shapely.geometry.shape:
    """Convert pixel coordinates to geocoordinates and vice versa, based on `affine_obj`.

    ``affine_obj`` is applied as ``x' = a*x + b*y + xoff`` and
    ``y' = d*x + e*y + yoff``, which converts pixel coordinates to
    geocoordinates. With ``inverse=True`` the inverse transform is applied,
    converting geocoordinates to pixel coordinates.

    Args:
        geom: shapely geometry to convert
        affine_obj: rasterio.Affine object to use for geoconversion
        inverse: If true, convert geocoordinates to pixel coordinates;
            otherwise convert pixel coordinates to geocoordinates

    Returns:
        input shape transformed by ``affine_obj`` (or by its inverse when
        ``inverse`` is True)

    .. versionadded:: 0.8
    """
    # Import the submodule explicitly: a bare ``import shapely`` is not
    # guaranteed to load ``shapely.affinity``.
    import shapely.affinity

    if inverse:
        affine_obj = ~affine_obj
    # shapely expects the 2D affine matrix flattened as [a, b, d, e, xoff, yoff]
    matrix = [
        affine_obj.a,
        affine_obj.b,
        affine_obj.d,
        affine_obj.e,
        affine_obj.xoff,
        affine_obj.yoff,
    ]
    return shapely.affinity.affine_transform(geom, matrix)