Source code for torchgeo.datasets.soda
# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.
"""SODA datasets."""

import json
import os
from collections.abc import Callable
from typing import ClassVar, Literal

import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from matplotlib.figure import Figure
from PIL import Image
from shapely import MultiPoint, Polygon
from torch import Tensor

from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
from .utils import Path, check_integrity, download_and_extract_archive, download_url


class SODAA(NonGeoDataset):
"""SODA-A dataset.

    The `SODA-A <https://shaunyuan22.github.io/SODA/>`_ dataset is a high-resolution
    aerial imagery dataset for small object detection.

    Dataset features:

    * 2513 images
    * 872,069 annotations with oriented bounding boxes
    * 9 classes (plus an ``other`` category)

    Dataset format:

    * Images are three-channel .jpg files.
    * Annotations are in JSON files.

    Classes:

    0. Airplane
    1. Helicopter
    2. Small vehicle
    3. Large vehicle
    4. Ship
    5. Container
    6. Storage tank
    7. Swimming-pool
    8. Windmill
    9. Other

If you use this dataset in your research, please cite the following paper:
* https://ieeexplore.ieee.org/document/10168277
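
    A minimal usage sketch (the ``root`` path below is only an example; pass
    ``download=True`` if the data is not already on disk):

    .. code-block:: python

        ds = SODAA(root='data/soda-a', split='train', bbox_orientation='oriented')
        sample = ds[0]  # dict with 'image', 'label', and 'bbox' keys
        fig = ds.plot(sample)
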
.. versionadded:: 0.7
"""
url = 'https://hf.co/datasets/torchgeo/soda-a/resolve/b082b9555ea9960d614b54a8cecde4cc63ec5481/{}'
files: ClassVar[dict[str, dict[str, str]]] = {
'images': {
'filename': 'Images.zip',
'md5sum': '8ee4ad7a306b0a0a900fa78a4f6aae68',
},
'labels': {
'filename': 'Annotations.zip',
'md5sum': '45b0d21209fc332d89b0144b308e57fa',
},
}
classes = (
'airplane',
'helicopter',
'small-vehicle',
'large-vehicle',
'ship',
'container',
'storage-tank',
'swimming-pool',
'windmill',
'other',
)
valid_splits = ('train', 'val', 'test')
valid_orientations = ('oriented', 'horizontal')

    def __init__(
self,
root: Path = 'data',
split: Literal['train', 'val', 'test'] = 'train',
bbox_orientation: Literal['oriented', 'horizontal'] = 'horizontal',
transforms: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
download: bool = False,
checksum: bool = False,
) -> None:
"""Initialize a new instance of SODA-A dataset.
Args:
root: root directory where dataset can be found
split: one of "train", "val", or "test"
bbox_orientation: one of "oriented" or "horizontal"
transforms: a function/transform that takes input sample and its target as
entry and returns a transformed version
download: if True, download dataset and store it in the root directory
checksum: if True, check the MD5 of the downloaded files (may be slow)
Raises:
AssertionError: if *split* or *bbox_orientation* argument is invalid
DatasetNotFoundError: If dataset is not found and *download* is False.
"""
assert split in self.valid_splits, f'split must be one of {self.valid_splits}'
assert bbox_orientation in self.valid_orientations, (
f'bbox_orientation must be one of {self.valid_orientations}'
)
self.root = root
self.split = split
self.bbox_orientation = bbox_orientation
self.transforms = transforms
self.download = download
self.checksum = checksum
self._verify()
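
        # sample_df.csv lists one row per sample with its split and the relative
        # image/label paths; keep only the rows for the requested split.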
self.sample_df = pd.read_csv(os.path.join(self.root, 'sample_df.csv'))
self.sample_df = self.sample_df[
self.sample_df['split'] == self.split
].reset_index(drop=True)

    def __len__(self) -> int:
"""Return the number of samples in the dataset."""
return len(self.sample_df)

    def __getitem__(self, idx: int) -> dict[str, Tensor]:
"""Return the sample at the given index.
Args:
idx: index of the sample to return
Returns:
the sample at the given index
"""
row = self.sample_df.iloc[idx]
image = self._load_image(os.path.join(self.root, row['image_path']))
boxes, labels = self._load_labels(os.path.join(self.root, row['label_path']))
sample: dict[str, Tensor] = {'image': image, 'label': labels}
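        # Oriented boxes are flattened corner coordinates of shape [N, 8], stored
        # under 'bbox'; horizontal boxes are [N, 4] (xmin, ymin, xmax, ymax), stored
        # under 'bbox_xyxy'.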
if self.bbox_orientation == 'oriented':
sample['bbox'] = boxes
else:
sample['bbox_xyxy'] = boxes
if self.transforms is not None:
sample = self.transforms(sample)
return sample

    def _load_image(self, path: str) -> Tensor:
"""Load an image from disk.
Args:
path: path to the image file
Returns:
the image as a tensor
"""
with Image.open(path) as img:
array: np.typing.NDArray[np.int_] = np.array(img.convert('RGB'))
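            # JPEG data decodes to uint8; torch.from_numpy keeps that dtype.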
tensor: Tensor = torch.from_numpy(array)
# Convert from HxWxC to CxHxW
tensor = tensor.permute((2, 0, 1))
return tensor

    def _load_labels(self, path: str) -> tuple[Tensor, Tensor]:
"""Load labels from disk.
Args:
path: path to the label file
Returns:
tuple of:
- boxes: tensor of bounding boxes in XYXY format [N, 4]
- labels: tensor of class labels [N]
"""
with open(path) as f:
data = json.load(f)
boxes = []
labels = []
for ann in data['annotations']:
# Extract polygon points
coords = ann['poly']
points = [
(coords[i], coords[i + 1])
for i in range(0, len(coords), 2)
if i + 1 < len(coords)
]
# Convert to axis-aligned bounding box
if self.bbox_orientation == 'horizontal':
shapely_poly = Polygon(points)
minx, miny, maxx, maxy = shapely_poly.bounds
boxes.append([minx, miny, maxx, maxy])
else: # convert to oriented bbox
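                # The minimum rotated rectangle of the convex hull is a tight oriented
                # box; its four corners are flattened to [x1, y1, x2, y2, x3, y3, x4, y4].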
hull = MultiPoint(points).convex_hull
min_rect = hull.minimum_rotated_rectangle
rect_coords = list(min_rect.exterior.coords)[:-1]
obb_coords = [coord for point in rect_coords for coord in point]
boxes.append(obb_coords)
labels.append(ann['category_id'])
boxes_tensor = torch.tensor(boxes, dtype=torch.float32)
labels_tensor = torch.tensor(labels, dtype=torch.long)
return boxes_tensor, labels_tensor

    def _verify(self) -> None:
"""Verify the integrity of the dataset."""
exists = []
df_path = os.path.join(self.root, 'sample_df.csv')
if os.path.exists(df_path):
exists.append(True)
df = pd.read_csv(df_path)
df = df[df['split'] == self.split].reset_index(drop=True)
for idx, row in df.iterrows():
image_path = os.path.join(self.root, row['image_path'])
label_path = os.path.join(self.root, row['label_path'])
exists.append(os.path.exists(image_path) and os.path.exists(label_path))
else:
exists.append(False)
if all(exists):
return
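
        # Extracted files are missing or incomplete; fall back to the zip archives.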
exists = []
for file in self.files.values():
archive_path = os.path.join(self.root, file['filename'])
if os.path.exists(archive_path):
if self.checksum and not check_integrity(archive_path, file['md5sum']):
raise RuntimeError('Dataset found, but corrupted.')
exists.append(True)
else:
exists.append(False)
if all(exists):
return
if not self.download:
raise DatasetNotFoundError(self)
self._download()

    def _download(self) -> None:
"""Download the dataset."""
for file in self.files.values():
download_and_extract_archive(
self.url.format(file['filename']),
self.root,
filename=file['filename'],
md5=file['md5sum'] if self.checksum else None,
)
# also download the sample_df
download_url(
self.url.format('sample_df.csv'), self.root, filename='sample_df.csv'
)

    def plot(
self,
sample: dict[str, Tensor],
show_titles: bool = True,
suptitle: str | None = None,
box_alpha: float = 0.7,
) -> Figure:
"""Plot a sample from the dataset with legend.
Args:
sample: a sample returned by :meth:`__getitem__`
show_titles: flag indicating whether to show titles above each panel
suptitle: optional string to use as a suptitle
box_alpha: alpha value for boxes
Returns:
a matplotlib Figure with the rendered sample
"""
image = sample['image'].permute((1, 2, 0)).numpy()
if self.bbox_orientation == 'horizontal':
boxes = sample['bbox_xyxy'].numpy()
else:
boxes = sample['bbox'].numpy()
labels = sample['label'].numpy()
fig, ax = plt.subplots(ncols=1, figsize=(10, 10))
ax.imshow(image)
        ax.axis('off')

        # Only draw a panel title when requested via ``show_titles``.
        if show_titles:
            ax.set_title('Ground Truth')
cm = plt.get_cmap('gist_rainbow')
unique_labels = set()
legend_elements = []
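
        # Draw each box in a per-class color and collect one legend entry per class.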
for box, label_idx in zip(boxes, labels):
color = cm(label_idx / len(self.classes))
label = self.classes[label_idx]
if self.bbox_orientation == 'horizontal':
# Horizontal box: [xmin, ymin, xmax, ymax]
x1, y1, x2, y2 = box
rect = patches.Rectangle(
(x1, y1),
x2 - x1,
y2 - y1,
linewidth=2,
alpha=box_alpha,
linestyle='solid',
edgecolor=color,
facecolor='none',
)
ax.add_patch(rect)
else:
# Oriented box: [x1,y1,x2,y2,x3,y3,x4,y4]
vertices = box.reshape(4, 2)
polygon = patches.Polygon(
vertices,
linewidth=2,
alpha=box_alpha,
linestyle='solid',
edgecolor=color,
facecolor='none',
)
ax.add_patch(polygon)
if label not in unique_labels:
legend_elements.append(
patches.Patch(facecolor=color, alpha=box_alpha, label=label)
)
unique_labels.add(label)
ax.legend(
handles=legend_elements,
loc='lower center',
ncol=len(legend_elements),
mode='expand',
)
if suptitle is not None:
plt.suptitle(suptitle)
return fig