[ ]:
# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.

Command-Line Interface

Written by: Adam J. Stewart

TorchGeo provides a command-line interface based on LightningCLI that allows users to combine our data modules and trainers from the comfort of the command line. This no-code solution can be attractive for both beginners and experts, as it offers flexibility and reproducibility. In this tutorial, we demonstrate some of the features of this interface.

Setup

First, we install TorchGeo. In addition to the Python library, this also installs a torchgeo executable.

[ ]:
%pip install torchgeo
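As a quick sanity check after installation, the following sketch looks up the executable on the PATH using only the standard library. The helper name `find_cli` is hypothetical and is not part of TorchGeo.

```python
import shutil
from typing import Optional


def find_cli(name: str = 'torchgeo') -> Optional[str]:
    """Return the full path of an executable on the PATH, or None if absent."""
    return shutil.which(name)


path = find_cli()
print(path if path else 'torchgeo executable not found on PATH')
```

If nothing is found, the environment that ran the install is probably not the one the shell is using.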

Subcommands

The torchgeo command has a number of subcommands that can be run. The --help flag can be used to list them.

[ ]:
!torchgeo --help

Trainer

Below, we run --help on the fit subcommand to see what options are available to us. fit is used to train and validate a model, and we can customize many aspects of the training process.

[ ]:
!torchgeo fit --help

Model

We must first select an nn.Module model architecture to train and a lightning.pytorch.LightningModule trainer to train it. We will experiment with the ClassificationTask trainer and use the --model.help flag to see what options we can customize. Any of TorchGeo’s built-in trainers, or trainers written by the user, can be used in this way.

[ ]:
!torchgeo fit --model.help ClassificationTask

Data

We must also select a Dataset to train on and a lightning.pytorch.LightningDataModule that provides the train/val/test split and any augmentations to apply to the data. As with the model, we use the --data.help flag to see what options are available for the EuroSAT100 data module.

[ ]:
!torchgeo fit --data.help EuroSAT100DataModule

Config

Now that we have seen the important configuration options, we can put them together in a YAML file. LightningCLI supports YAML, JSON, and command-line configuration. We write this file from Python in this tutorial, but normally you would write it in your favorite text editor.

[ ]:
import os
import tempfile

root = os.path.join(tempfile.gettempdir(), 'eurosat100')
config = f"""
trainer:
  max_epochs: 1
  default_root_dir: '{root}'
model:
  class_path: ClassificationTask
  init_args:
    model: 'resnet18'
    in_channels: 13
    num_classes: 10
data:
  class_path: EuroSAT100DataModule
  init_args:
    batch_size: 8
  dict_kwargs:
    root: '{root}'
    download: true
"""
os.makedirs(root, exist_ok=True)
with open(os.path.join(root, 'config.yaml'), 'w') as f:
    f.write(config)
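Since JSON is listed among the supported formats, the same settings can also be written as a JSON file. The sketch below builds the configuration as a nested dictionary and serializes it with the standard library. The file name `config.json` and the variable names are our own choices for illustration, and we assume the resulting file can be passed to --config in place of the YAML file.

```python
import json
import os
import tempfile

# Same location as the YAML example above
root = os.path.join(tempfile.gettempdir(), 'eurosat100')

# Same settings as config.yaml, expressed as a nested dictionary
config_dict = {
    'trainer': {'max_epochs': 1, 'default_root_dir': root},
    'model': {
        'class_path': 'ClassificationTask',
        'init_args': {'model': 'resnet18', 'in_channels': 13, 'num_classes': 10},
    },
    'data': {
        'class_path': 'EuroSAT100DataModule',
        'init_args': {'batch_size': 8},
        'dict_kwargs': {'root': root, 'download': True},
    },
}

os.makedirs(root, exist_ok=True)
json_path = os.path.join(root, 'config.json')  # hypothetical file name
with open(json_path, 'w') as f:
    json.dump(config_dict, f, indent=2)
```

Building the configuration as a dictionary can be convenient when the values are computed in Python, for example in a hyperparameter sweep.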

This YAML file has three sections:

  • trainer: Arguments to pass to the Trainer

  • model: Arguments to pass to the task

  • data: Arguments to pass to the data module

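The class_path gives the class to instantiate. init_args lists arguments that appear by name in that class's constructor signature, while dict_kwargs lists extra keyword arguments that the constructor accepts only through **kwargs (here, root and download, which the data module forwards to the dataset).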
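To make the roles of these three keys concrete, here is a deliberately simplified sketch of how such a section could be turned into an object. It is not how LightningCLI is implemented: `ToyDataModule`, `instantiate`, and the `registry` lookup are hypothetical stand-ins, and the real CLI also performs type checking and class resolution that this sketch omits.

```python
from typing import Any


class ToyDataModule:
    """Hypothetical stand-in for a data module; not a TorchGeo class."""

    def __init__(self, batch_size: int = 64, **kwargs: Any) -> None:
        self.batch_size = batch_size  # named in the signature: an init_args entry
        self.kwargs = kwargs  # collected by **kwargs: the dict_kwargs entries


def instantiate(spec: dict[str, Any], registry: dict[str, type]) -> Any:
    """Build an object from a class_path / init_args / dict_kwargs mapping."""
    cls = registry[spec['class_path']]
    init_args = spec.get('init_args', {})
    dict_kwargs = spec.get('dict_kwargs', {})
    return cls(**init_args, **dict_kwargs)


spec = {
    'class_path': 'ToyDataModule',
    'init_args': {'batch_size': 8},
    'dict_kwargs': {'root': '/tmp/eurosat100', 'download': True},
}
dm = instantiate(spec, {'ToyDataModule': ToyDataModule})
print(dm.batch_size, dm.kwargs)
# prints: 8 {'root': '/tmp/eurosat100', 'download': True}
```

Both groups end up as keyword arguments to the constructor. The distinction matters mainly to the CLI, which can validate init_args against the signature but has to pass dict_kwargs through unchecked.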

Training

We can now train our model like so.

[ ]:
!torchgeo fit --config {root}/config.yaml
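Because command-line configuration is also supported, individual settings can be overridden for a single run without editing the file. The sketch below only assembles and prints such a command. `build_fit_command` is a hypothetical helper, and we assume that options given after --config take precedence over the values in the file.

```python
import shlex


def build_fit_command(config_path: str, overrides: dict[str, object]) -> list[str]:
    """Assemble a `torchgeo fit` argument list with per-run overrides."""
    cmd = ['torchgeo', 'fit', '--config', config_path]
    for key, value in overrides.items():
        # Dotted keys follow the nesting of the YAML file, e.g. trainer.max_epochs
        cmd.append(f'--{key}={value}')
    return cmd


cmd = build_fit_command('config.yaml', {'trainer.max_epochs': 2})
print(shlex.join(cmd))
```

The printed command can be pasted into a shell, or the list can be passed to `subprocess.run` to launch training from Python.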

Validation

Now that we have a trained model, we can evaluate performance on the validation set. Note that we need to explicitly pass in the location of the checkpoint from the previous run.

[ ]:
import glob

checkpoint = glob.glob(
    os.path.join(root, 'lightning_logs', 'version_0', 'checkpoints', '*.ckpt')
)[0]

!torchgeo validate --config {root}/config.yaml --ckpt_path {checkpoint}
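The glob above hard-codes version_0, which matches the first run in a fresh directory. If fit is run more than once, later runs are typically logged under version_1, version_2, and so on. A minimal sketch for picking the most recently written checkpoint across all versions follows. `latest_checkpoint` and `newest` are hypothetical names, and the directory layout is assumed to match the one used above.

```python
import glob
import os
import tempfile
from typing import Optional


def latest_checkpoint(root: str) -> Optional[str]:
    """Return the most recently modified .ckpt under root, or None if none exist."""
    pattern = os.path.join(root, 'lightning_logs', 'version_*', 'checkpoints', '*.ckpt')
    candidates = glob.glob(pattern)
    return max(candidates, key=os.path.getmtime) if candidates else None


root = os.path.join(tempfile.gettempdir(), 'eurosat100')
newest = latest_checkpoint(root)
print(newest)
```

Sorting by modification time rather than by version number avoids parsing directory names, at the cost of depending on file timestamps.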

Testing

After finishing our hyperparameter tuning, we can calculate and report the final test performance.

[ ]:
!torchgeo test --config {root}/config.yaml --ckpt_path {checkpoint}

Additional Reading

Lightning CLI has many more features that are worth learning. You can learn more by reading the Lightning CLI tutorials in the Lightning documentation.
