[ ]:
# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.
Command-Line Interface¶
Written by: Adam J. Stewart
TorchGeo provides a command-line interface based on LightningCLI that allows users to combine our data modules and trainers from the comfort of the command line. This no-code solution can be attractive for both beginners and experts, as it offers flexibility and reproducibility. In this tutorial, we demonstrate some of the features of this interface.
Setup¶
First, we install TorchGeo. In addition to the Python library, this also installs a torchgeo executable.
[ ]:
%pip install torchgeo
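If you want to confirm that the installation succeeded before continuing, one option (not part of the original workflow) is to import the library and print its version.
[ ]:
# Optional sanity check: confirm that TorchGeo is importable and report its version
import torchgeo

print(torchgeo.__version__)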
Subcommands¶
The torchgeo command has a number of subcommands that can be run. The --help flag can be used to list them.
[ ]:
!torchgeo --help
Trainer¶
Below, we run --help on the fit subcommand to see what options are available to us. fit is used to train and validate a model, and we can customize many aspects of the training process.
[ ]:
!torchgeo fit --help
Model¶
We must first select an nn.Module model architecture to train and a lightning.pytorch.LightningModule trainer to train it. We will experiment with the ClassificationTask trainer and see what options we can customize. Any of TorchGeo’s builtin trainers, or trainers written by the user, can be used in this way.
[ ]:
!torchgeo fit --model.help ClassificationTask
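The options printed above correspond to the arguments of the ClassificationTask constructor; they are the same values we will later place under init_args in the config file. As a rough, optional sketch, the equivalent task could also be constructed directly in Python.
[ ]:
# Optional sketch: build the same task in Python instead of via the CLI
from torchgeo.trainers import ClassificationTask

task = ClassificationTask(model='resnet18', in_channels=13, num_classes=10)
print(task.hparams)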
Data¶
We must also select a Dataset we would like to train on and a lightning.pytorch.LightningDataModule we can use to access the train/val/test split and any augmentations to apply to the data. Similarly, we use the --help flag to see what options are available for the EuroSAT100 dataset.
[ ]:
!torchgeo fit --data.help EuroSAT100DataModule
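These options likewise mirror the EuroSAT100DataModule constructor; arguments such as root and download belong to the underlying EuroSAT100 dataset and are forwarded through **kwargs. Below is a hedged sketch of the equivalent Python object (the root path is only an example).
[ ]:
# Optional sketch: build the same data module in Python.
# root and download are forwarded to the underlying EuroSAT100 dataset;
# the path below is an arbitrary example.
from torchgeo.datamodules import EuroSAT100DataModule

datamodule = EuroSAT100DataModule(batch_size=8, root='data/eurosat100', download=True)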
Config¶
Now that we have seen all of the important configuration options, we can put them together in a YAML file. LightningCLI supports YAML, JSON, and command-line configuration. While we will write this file using Python in this tutorial, normally it would be written in your favorite text editor.
[ ]:
import os
import tempfile

root = os.path.join(tempfile.gettempdir(), 'eurosat100')
config = f"""
trainer:
  max_epochs: 1
  default_root_dir: '{root}'
model:
  class_path: ClassificationTask
  init_args:
    model: 'resnet18'
    in_channels: 13
    num_classes: 10
data:
  class_path: EuroSAT100DataModule
  init_args:
    batch_size: 8
  dict_kwargs:
    root: '{root}'
    download: true
"""
os.makedirs(root, exist_ok=True)
with open(os.path.join(root, 'config.yaml'), 'w') as f:
    f.write(config)
This YAML file has three sections:
trainer: Arguments to pass to the Trainer
model: Arguments to pass to the task
data: Arguments to pass to the data module
The class_path gives the class to instantiate, init_args lists arguments that appear explicitly in that class’s constructor, and dict_kwargs lists additional keyword arguments that are forwarded as **kwargs (here, to the underlying EuroSAT100 dataset).
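Because LightningCLI also accepts configuration on the command line, any value in this file can be overridden by a flag with the same dotted name. As an illustration, the cell below asks LightningCLI to print the fully merged configuration, with max_epochs overridden, without actually training; this assumes the standard LightningCLI --print_config flag and dotted-flag syntax.
[ ]:
# Print the merged config (YAML file + command-line override) without training
!torchgeo fit --config {root}/config.yaml --trainer.max_epochs 2 --print_config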
Training¶
We can now train our model like so.
[ ]:
!torchgeo fit --config {root}/config.yaml
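Training writes logs and checkpoints under the default_root_dir we set in the config. If you are curious, you can list what was produced (optional):
[ ]:
# Optional: list the files Lightning wrote under default_root_dir
for dirpath, _, filenames in os.walk(os.path.join(root, 'lightning_logs')):
    for name in filenames:
        print(os.path.join(dirpath, name))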
Validation¶
Now that we have a trained model, we can evaluate performance on the validation set. Note that we need to explicitly pass in the location of the checkpoint from the previous run.
[ ]:
import glob

checkpoint = glob.glob(
    os.path.join(root, 'lightning_logs', 'version_0', 'checkpoints', '*.ckpt')
)[0]
!torchgeo validate --config {root}/config.yaml --ckpt_path {checkpoint}
Testing¶
After finishing our hyperparameter tuning, we can calculate and report the final test performance.
[ ]:
!torchgeo test --config {root}/config.yaml --ckpt_path {checkpoint}
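The same checkpoint can also be loaded back into Python outside of the CLI, for example for inference. This is a hedged sketch using Lightning’s standard load_from_checkpoint method rather than anything specific to the torchgeo executable.
[ ]:
# Sketch: load the trained weights back into a task object for use in Python
from torchgeo.trainers import ClassificationTask

task = ClassificationTask.load_from_checkpoint(checkpoint)
task.eval()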
Additional Reading¶
Lightning CLI has many more features that are worth learning. You can learn more by reading the following set of tutorials: