{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "16421d50-8d7a-4972-b06f-160fd890cc86", "metadata": {}, "outputs": [], "source": [ "# Copyright (c) TorchGeo Contributors. All rights reserved.\n", "# Licensed under the MIT License." ] }, { "cell_type": "markdown", "id": "e563313d", "metadata": {}, "source": [ "# Command-Line Interface\n", "\n", "_Written by: Adam J. Stewart_\n", "\n", "TorchGeo provides a command-line interface based on [LightningCLI](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.cli.LightningCLI.html) that allows users to combine our data modules and trainers from the comfort of the command line. This no-code solution can be attractive for both beginners and experts, as it offers flexibility and reproducibility. In this tutorial, we demonstrate some of the features of this interface." ] }, { "cell_type": "markdown", "id": "8c1f4156", "metadata": {}, "source": [ "## Setup\n", "\n", "First, we install TorchGeo. In addition to the Python library, this also installs a `torchgeo` executable." ] }, { "cell_type": "code", "execution_count": null, "id": "3f0d31a8", "metadata": {}, "outputs": [], "source": [ "%pip install torchgeo" ] }, { "cell_type": "markdown", "id": "7801ab8b-0ee3-40ac-88c2-4bdc29bb4e1b", "metadata": {}, "source": [ "## Subcommands\n", "\n", "The `torchgeo` command has a number of *subcommands* that can be run. The `--help` flag can be used to list them." ] }, { "cell_type": "code", "execution_count": null, "id": "a6ccac4e-7f20-4aa8-b851-27234ffd259f", "metadata": {}, "outputs": [], "source": [ "!torchgeo --help" ] }, { "cell_type": "markdown", "id": "19ee017d-0d8f-41c6-8e7c-68495c7e62b6", "metadata": {}, "source": [ "## Trainer\n", "\n", "Below, we run `--help` on the `fit` subcommand to see what options are available to us. `fit` is used to train and validate a model, and we can customize many aspects of the training process." 
] }, { "cell_type": "code", "execution_count": null, "id": "afe1dc9d-4cee-43b0-ae30-200c64d3401a", "metadata": {}, "outputs": [], "source": [ "!torchgeo fit --help" ] }, { "cell_type": "markdown", "id": "b437860c-b406-4150-b30b-8aa895eebfcd", "metadata": {}, "source": [ "## Model\n", "\n", "We must first select an `nn.Module` model architecture to train and a `lightning.pytorch.LightningModule` trainer to train it. We will experiment with the `ClassificationTask` trainer and see what options we can customize. Any of TorchGeo's built-in trainers, or trainers written by the user, can be used in this way." ] }, { "cell_type": "code", "execution_count": null, "id": "7cd9bbd0-17c9-4e87-b10d-ea846c39bc24", "metadata": {}, "outputs": [], "source": [ "!torchgeo fit --model.help ClassificationTask" ] }, { "cell_type": "markdown", "id": "3daacd8d-64f4-4357-bdf3-759295a14224", "metadata": {}, "source": [ "## Data\n", "\n", "We must also select a `Dataset` we would like to train on and a `lightning.pytorch.LightningDataModule` we can use to access the train/val/test split and any augmentations to apply to the data. Similarly, we use the `--data.help` flag to see what options are available for the `EuroSAT100DataModule`." ] }, { "cell_type": "code", "execution_count": null, "id": "136eb59f-6662-44af-82e9-c55bdb3f17ac", "metadata": {}, "outputs": [], "source": [ "!torchgeo fit --data.help EuroSAT100DataModule" ] }, { "cell_type": "markdown", "id": "8039cb67-ee18-4b41-8bf5-0e939493f5bb", "metadata": {}, "source": [ "## Config\n", "\n", "Now that we have seen all the important configuration options, we can put them together in a YAML file. LightningCLI supports YAML, JSON, and command-line configuration. While we will write this file using Python in this tutorial, normally this file would be written in your favorite text editor." 
] }, { "cell_type": "code", "execution_count": null, "id": "e25c8efb-ed8c-4795-862c-bfb84cc84e1f", "metadata": {}, "outputs": [], "source": [ "import os\n", "import tempfile\n", "\n", "root = os.path.join(tempfile.gettempdir(), 'eurosat100')\n", "config = f\"\"\"\n", "trainer:\n", "  max_epochs: 1\n", "  default_root_dir: '{root}'\n", "model:\n", "  class_path: ClassificationTask\n", "  init_args:\n", "    model: 'resnet18'\n", "    in_channels: 13\n", "    num_classes: 10\n", "data:\n", "  class_path: EuroSAT100DataModule\n", "  init_args:\n", "    batch_size: 8\n", "  dict_kwargs:\n", "    root: '{root}'\n", "    download: true\n", "\"\"\"\n", "os.makedirs(root, exist_ok=True)\n", "with open(os.path.join(root, 'config.yaml'), 'w') as f:\n", "    f.write(config)" ] }, { "cell_type": "markdown", "id": "a661b8d7-2dc9-4a30-8842-bd52d130e080", "metadata": {}, "source": [ "This YAML file has three sections:\n", "\n", "* trainer: Arguments to pass to the [Trainer](https://lightning.ai/docs/pytorch/stable/common/trainer.html)\n", "* model: Arguments to pass to the task\n", "* data: Arguments to pass to the data module\n", "\n", "The `class_path` gives the class to instantiate, `init_args` lists arguments that appear in the class's signature, and `dict_kwargs` lists additional keyword arguments that are forwarded through `**kwargs`." ] }, { "cell_type": "markdown", "id": "e132f933-4edf-42bb-b585-e0d8ceb65eab", "metadata": {}, "source": [ "## Training\n", "\n", "We can now train our model like so." ] }, { "cell_type": "code", "execution_count": null, "id": "f84b0739-c9e7-4057-8864-98ab69a11f64", "metadata": {}, "outputs": [], "source": [ "!torchgeo fit --config {root}/config.yaml" ] }, { "cell_type": "markdown", "id": "cb1557f1-6cc0-46da-909c-836911acb248", "metadata": {}, "source": [ "## Validation\n", "\n", "Now that we have a trained model, we can evaluate performance on the validation set. Note that we need to explicitly pass in the location of the checkpoint from the previous run." 
] }, { "cell_type": "code", "execution_count": null, "id": "b9cbb4f4-1879-4ae7-bae4-2c24d49a4a61", "metadata": {}, "outputs": [], "source": [ "import glob\n", "\n", "checkpoint = glob.glob(\n", " os.path.join(root, 'lightning_logs', 'version_0', 'checkpoints', '*.ckpt')\n", ")[0]\n", "\n", "!torchgeo validate --config {root}/config.yaml --ckpt_path {checkpoint}" ] }, { "cell_type": "markdown", "id": "ba816fc3-5cac-4cbc-a6ef-effc6c9faa61", "metadata": {}, "source": [ "## Testing\n", "\n", "After finishing our hyperparameter tuning, we can calculate and report the final test performance." ] }, { "cell_type": "code", "execution_count": null, "id": "f1faa997-9f81-4847-94fc-5a8bb7687369", "metadata": {}, "outputs": [], "source": [ "!torchgeo test --config {root}/config.yaml --ckpt_path {checkpoint}" ] }, { "cell_type": "markdown", "id": "f5383d30-8f76-44a2-8366-e6fcbd1e6042", "metadata": {}, "source": [ "## Additional Reading\n", "\n", "Lightning CLI has many more features that are worth learning. You can learn more by reading the following set of tutorials:\n", "\n", "* [Configure hyperparameters from the CLI](https://lightning.ai/docs/pytorch/stable/cli/lightning_cli.html)" ] } ], "metadata": { "accelerator": "GPU", "execution": { "timeout": 1200 }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.13.0" } }, "nbformat": 4, "nbformat_minor": 5 }