Source code for torchgeo.datasets.dota

# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.

"""DOTA dataset."""

import os
from collections.abc import Callable
from typing import Any, ClassVar, Literal

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from matplotlib import patches
from matplotlib.figure import Figure
from PIL import Image
from torch import Tensor

from .errors import DatasetNotFoundError
from .geo import NonGeoDataset
from .utils import (
    Path,
    check_integrity,
    download_url,
    extract_archive,
    percentile_normalization,
)


class DOTA(NonGeoDataset):
    """DOTA dataset.

    `DOTA <https://captain-whu.github.io/DOTA/index.html>`__ is a large-scale object
    detection dataset for aerial imagery containing RGB and gray-scale imagery from
    Google Earth, GF-2 and JL-1 satellites as well as additional aerial imagery from
    CycloMedia. There are three versions of the dataset: v1.0, v1.5, and v2.0, where
    v1.0 and v1.5 have the same images but different annotations, and v2.0 extends
    both the images and annotations with more samples.

    Dataset features:

    * 1869 samples in v1.0 and v1.5 and 2423 samples in v2.0
    * multi-class object detection (15 classes in v1.0 and v1.5 and 18 classes in v2.0)
    * horizontal and oriented bounding boxes

    Dataset format:

    * images are three channel PNGs with various pixel sizes
    * annotations are text files with one line per bounding box

    Classes:

    0. plane
    1. ship
    2. storage-tank
    3. baseball-diamond
    4. tennis-court
    5. basketball-court
    6. ground-track-field
    7. harbor
    8. bridge
    9. large-vehicle
    10. small-vehicle
    11. helicopter
    12. roundabout
    13. soccer-ball-field
    14. swimming-pool
    15. container-crane (v1.5+)
    16. airport (v2.0+)
    17. helipad (v2.0+)

    If you use this work in your research, please cite the following papers:

    * https://arxiv.org/abs/2102.12219
    * https://arxiv.org/abs/1711.10398

    .. versionadded:: 0.7
    """

    url = 'https://hf.co/datasets/isaaccorley/dota/resolve/672e63236622f7da6ee37fca44c50ac368b77cab/{}'

    file_info: ClassVar[dict[str, dict[str, dict[str, dict[str, str]]]]] = {
        'train': {
            'images': {
                '1.0': {
                    'filename': 'dotav1.0_images_train.tar.gz',
                    'md5': '363b472dc3c71e7fa2f4a60223b437ea',
                },
                '1.5': {
                    'filename': 'dotav1.0_images_train.tar.gz',
                    'md5': '363b472dc3c71e7fa2f4a60223b437ea',
                },
                '2.0': {
                    'filename': 'dotav2.0_images_train.tar.gz',
                    'md5': '91ae5212d170330ab9f65ccb6c675763',
                },
            },
            'annotations': {
                '1.0': {
                    'filename': 'dotav1.0_annotations_train.tar.gz',
                    'md5': 'f6788257bcc4d29018344a4128e3734a',
                },
                '1.5': {
                    'filename': 'dotav1.5_annotations_train.tar.gz',
                    'md5': '0da97e5623a87d7bec22e75f6978dbce',
                },
                '2.0': {
                    'filename': 'dotav2.0_annotations_train.tar.gz',
                    'md5': '04d3d626df2203053b7f06581b3b0667',
                },
            },
        },
        'val': {
            'images': {
                '1.0': {
                    'filename': 'dotav1.0_images_val.tar.gz',
                    'md5': '42293219ba61d61c417ae558bbe1f2ba',
                },
                '1.5': {
                    'filename': 'dotav1.0_images_val.tar.gz',
                    'md5': '42293219ba61d61c417ae558bbe1f2ba',
                },
                '2.0': {
                    'filename': 'dotav2.0_images_val.tar.gz',
                    'md5': '737f65edf54b5aa627b3d48b0e253095',
                },
            },
            'annotations': {
                '1.0': {
                    'filename': 'dotav1.0_annotations_val.tar.gz',
                    'md5': '28155c05b1dc3a0f5cb6b9bdfef85a13',
                },
                '1.5': {
                    'filename': 'dotav1.5_annotations_val.tar.gz',
                    'md5': '85bf945788784cf9b4f1c714453178fc',
                },
                '2.0': {
                    'filename': 'dotav2.0_annotations_val.tar.gz',
                    'md5': 'ec53c1dbcfc125d7532bd6a065c647ac',
                },
            },
        },
    }

    sample_df_path = 'samples.csv'

    classes = (
        'plane',
        'ship',
        'storage-tank',
        'baseball-diamond',
        'tennis-court',
        'basketball-court',
        'ground-track-field',
        'harbor',
        'bridge',
        'large-vehicle',
        'small-vehicle',
        'helicopter',
        'roundabout',
        'soccer-ball-field',
        'swimming-pool',
        'container-crane',
        'airport',
        'helipad',
    )

    valid_splits = ('train', 'val')
    valid_versions = ('1.0', '1.5', '2.0')
    valid_orientations = ('horizontal', 'oriented')
    def __init__(
        self,
        root: Path = 'data',
        split: Literal['train', 'val'] = 'train',
        version: Literal['1.0', '1.5', '2.0'] = '2.0',
        bbox_orientation: Literal['horizontal', 'oriented'] = 'oriented',
        transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
        download: bool = False,
        checksum: bool = False,
    ) -> None:
        """Initialize a new DOTA dataset instance.

        Args:
            root: root directory where dataset can be found
            split: split of the dataset to use, one of ['train', 'val']
            version: version of the dataset to use, one of ['1.0', '1.5', '2.0']
            bbox_orientation: bounding box orientation, one of ['horizontal', 'oriented'],
                where horizontal returns xyxy format and oriented returns
                x1y1x2y2x3y3x4y4 format
            transforms: a function/transform that takes input sample and its target as
                entry and returns a transformed version
            download: if True, download dataset and store it in the root directory
            checksum: if True, check the MD5 of the downloaded files (may be slow)

        Raises:
            AssertionError: if *split*, *version*, or *bbox_orientation* arguments
                are not valid
            DatasetNotFoundError: if dataset is not found or corrupted, and *download*
                is False
        """
        assert split in self.valid_splits, (
            f"Split '{split}' not supported, use one of {self.valid_splits}"
        )
        assert version in self.valid_versions, (
            f"Version '{version}' not supported, use one of {self.valid_versions}"
        )
        assert bbox_orientation in self.valid_orientations, (
            f'Bounding box orientation must be one of {self.valid_orientations}'
        )

        self.root = root
        self.split = split
        self.version = version
        self.transforms = transforms
        self.download = download
        self.checksum = checksum
        self.bbox_orientation = bbox_orientation

        self._verify()

        self.sample_df = pd.read_csv(os.path.join(self.root, 'samples.csv'))
        self.sample_df['version'] = self.sample_df['version'].astype(str)
        self.sample_df = self.sample_df[self.sample_df['split'] == self.split]
        self.sample_df = self.sample_df[
            self.sample_df['version'] == self.version
        ].reset_index(drop=True)
    def __len__(self) -> int:
        """Return the number of samples in the dataset.

        Returns:
            length of the dataset
        """
        return len(self.sample_df)
    def __getitem__(self, index: int) -> dict[str, Any]:
        """Return an index within the dataset.

        Args:
            index: index to return

        Returns:
            data and label at that index
        """
        sample_row = self.sample_df.iloc[index]

        sample = {'image': self._load_image(sample_row['image_path'])}

        boxes, labels = self._load_annotations(sample_row['annotation_path'])

        if self.bbox_orientation == 'horizontal':
            sample['bbox_xyxy'] = boxes
        else:
            sample['bbox'] = boxes
        sample['labels'] = labels

        if self.transforms is not None:
            sample = self.transforms(sample)

        return sample
    def _load_image(self, path: str) -> Tensor:
        """Load image.

        Args:
            path: path to image file

        Returns:
            image: image tensor
        """
        image = Image.open(os.path.join(self.root, path)).convert('RGB')
        return torch.from_numpy(np.array(image).transpose(2, 0, 1)).float()

    def _load_annotations(self, path: str) -> tuple[Tensor, Tensor]:
        """Load DOTA annotations from text file.

        Format:
            x1 y1 x2 y2 x3 y3 x4 y4 class difficult

        Some files have 2 header lines that need to be skipped:
            imagesource:GoogleEarth
            gsd:0.146343590398

        Args:
            path: path to annotation file

        Returns:
            tuple of:
                boxes: tensor of shape (N, 8) with coordinates for oriented and
                    (N, 4) for horizontal
                labels: tensor of shape (N,) with class indices
        """
        with open(os.path.join(self.root, path)) as f:
            lines = f.readlines()

        # Skip header if present
        start_idx = 0
        if lines and lines[0].startswith('imagesource'):
            start_idx = 2

        boxes = []
        labels = []
        for line in lines[start_idx:]:
            parts = line.strip().split(' ')

            # Always read 8 coordinates
            coords = [float(p) for p in parts[:8]]
            label = parts[8]
            labels.append(self.classes.index(label))

            if self.bbox_orientation == 'horizontal':
                # Convert to [xmin, ymin, xmax, ymax] format
                x_coords = coords[::2]  # even indices (0,2,4,6)
                y_coords = coords[1::2]  # odd indices (1,3,5,7)
                xmin, xmax = min(x_coords), max(x_coords)
                ymin, ymax = min(y_coords), max(y_coords)
                boxes.append([xmin, ymin, xmax, ymax])
            else:
                boxes.append(coords)

        if not boxes:
            return (
                torch.zeros((0, 4 if self.bbox_orientation == 'horizontal' else 8)),
                torch.zeros(0, dtype=torch.long),
            )
        else:
            return torch.tensor(boxes), torch.tensor(labels)

    def _verify(self) -> None:
        """Verify dataset integrity and download/extract if needed."""
        # check if directories and sample file are present
        required_dirs = [
            os.path.join(self.root, self.split, 'images'),
            os.path.join(
                self.root, self.split, 'annotations', f'version{self.version}'
            ),
            os.path.join(self.root, self.sample_df_path),
        ]
        if all(os.path.exists(d) for d in required_dirs):
            return

        # Check for compressed files; v1.0 and v1.5 have the same images
        # but different annotations
        files_needed = [
            (
                self.file_info[self.split]['images'][self.version]['filename'],
                self.file_info[self.split]['images'][self.version]['md5'],
            ),
            (
                self.file_info[self.split]['annotations'][self.version]['filename'],
                self.file_info[self.split]['annotations'][self.version]['md5'],
            ),
        ]

        # For v2.0, also need v1.0 image files, but only v2.0 annotations
        if self.version == '2.0':
            files_needed.append(
                (
                    self.file_info[self.split]['images']['1.0']['filename'],
                    self.file_info[self.split]['images']['1.0']['md5'],
                )
            )

        # Check if archives exist and verify checksums if requested
        exists = []
        for filename, md5 in files_needed:
            filepath = os.path.join(self.root, filename)
            if os.path.exists(filepath):
                if self.checksum:
                    if not check_integrity(filepath, md5):
                        raise RuntimeError(f'Archive {filename} corrupted')
                exists.append(True)
                self._extract([(filename, md5)])
            else:
                exists.append(False)

        if all(exists):
            return

        if not self.download:
            raise DatasetNotFoundError(self)

        # also download the metadata file
        self._download(files_needed)
        self._extract(files_needed)

    def _download(self, files_needed: list[tuple[str, str]]) -> None:
        """Download the dataset.

        Args:
            files_needed: list of files to download for the particular version
        """
        for filename, md5 in files_needed:
            if not os.path.exists(os.path.join(self.root, filename)):
                download_url(
                    url=self.url.format(filename),
                    root=self.root,
                    filename=filename,
                    md5=None if not self.checksum else md5,
                )

        if not os.path.exists(os.path.join(self.root, self.sample_df_path)):
            download_url(
                url=self.url.format(self.sample_df_path),
                root=self.root,
                filename=self.sample_df_path,
            )

    def _extract(self, files_needed: list[tuple[str, str]]) -> None:
        """Extract the dataset.

        Args:
            files_needed: list of files to extract for the particular version
        """
        for filename, _ in files_needed:
            filepath = os.path.join(self.root, filename)
            extract_archive(filepath, self.root)
    def plot(
        self,
        sample: dict[str, Tensor],
        show_titles: bool = True,
        suptitle: str | None = None,
        box_alpha: float = 0.7,
    ) -> Figure:
        """Plot a sample from the dataset.

        Args:
            sample: a sample returned by __getitem__
            show_titles: flag indicating whether to show titles
            suptitle: optional string to use as a suptitle
            box_alpha: alpha value for boxes

        Returns:
            a matplotlib Figure with the rendered sample
        """
        image = percentile_normalization(sample['image'].permute(1, 2, 0).numpy())

        if self.bbox_orientation == 'horizontal':
            boxes = sample['bbox_xyxy'].cpu().numpy()
        else:
            boxes = sample['bbox'].cpu().numpy()

        labels = sample['labels'].cpu().numpy()

        fig, ax = plt.subplots(figsize=(10, 10))
        ax.imshow(image)
        ax.axis('off')

        # Create color map for classes
        cm = plt.get_cmap('gist_rainbow')

        for box, label_idx in zip(boxes, labels):
            color = cm(label_idx / len(self.classes))
            label = self.classes[label_idx]

            if self.bbox_orientation == 'horizontal':
                # Horizontal box: [xmin, ymin, xmax, ymax]
                x1, y1, x2, y2 = box
                rect = patches.Rectangle(
                    (x1, y1),
                    x2 - x1,
                    y2 - y1,
                    linewidth=2,
                    alpha=box_alpha,
                    linestyle='solid',
                    edgecolor=color,
                    facecolor='none',
                )
                ax.add_patch(rect)
                # Add label above box
                ax.text(
                    x1,
                    y1 - 5,
                    label,
                    color='white',
                    fontsize=8,
                    bbox=dict(facecolor=color, alpha=box_alpha),
                )
            else:
                # Oriented box: [x1,y1,x2,y2,x3,y3,x4,y4]
                vertices = box.reshape(4, 2)
                polygon = patches.Polygon(
                    vertices,
                    linewidth=2,
                    alpha=box_alpha,
                    linestyle='solid',
                    edgecolor=color,
                    facecolor='none',
                )
                ax.add_patch(polygon)
                # Add label at centroid
                centroid_x = vertices[:, 0].mean()
                centroid_y = vertices[:, 1].mean()
                ax.text(
                    centroid_x,
                    centroid_y,
                    label,
                    color='white',
                    fontsize=8,
                    bbox=dict(facecolor=color, alpha=box_alpha),
                    ha='center',
                    va='center',
                )

        if suptitle is not None:
            plt.suptitle(suptitle)

        return fig
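
The listing above is the complete dataset implementation. Below is a minimal usage sketch, not part of the module source, assuming the DOTA v2.0 archives can be downloaded from the mirror referenced by the class (or already sit extracted under ./data); the sample keys shown follow __getitem__ above.

# Usage sketch (not part of torchgeo.datasets.dota), assuming download is
# permitted or the archives are already extracted under ./data.
from torchgeo.datasets import DOTA

ds = DOTA(
    root='data',
    split='train',
    version='2.0',
    bbox_orientation='oriented',
    download=True,
)
print(len(ds))                 # number of samples in the chosen split/version

sample = ds[0]
print(sample['image'].shape)   # (3, H, W) float tensor
print(sample['bbox'].shape)    # (N, 8) corner coordinates x1y1...x4y4
print(sample['labels'].shape)  # (N,) class indices into DOTA.classes

fig = ds.plot(sample, suptitle='DOTA sample')
fig.savefig('dota_sample.png')

With bbox_orientation='horizontal', the boxes are instead returned under the 'bbox_xyxy' key with shape (N, 4), computed from the per-axis minima and maxima of the four corners.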
