[ ]:
# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.

Introduction to PyTorch

Written by: Adam J. Stewart

In this tutorial, we introduce the basics of deep learning with PyTorch. Understanding deep learning terminology and the training and evaluation pipeline in PyTorch is essential to using TorchGeo.

Setup

First, we install TorchGeo and all of its dependencies, including PyTorch.

[ ]:
%pip install torchgeo

Imports

Next, we import PyTorch, TorchGeo, and any other libraries we need. We also manually set the random seed to ensure the reproducibility of our experiments.

[ ]:
import os
import tempfile

import kornia.augmentation as K
import torch
from torch import nn, optim
from torch.utils.data import DataLoader

from torchgeo.datasets import EuroSAT100
from torchgeo.models import ResNet18_Weights, resnet18

torch.manual_seed(0)

Definitions

If this is your first introduction to deep learning (DL), a natural question might be “what is deep learning?”. You may also be curious how it relates to other similar buzzwords, including artificial intelligence (AI) and machine learning (ML). We can define these terms as follows:

  • AI: when machines exhibit human intelligence

  • ML: when machines learn from example

  • DL: when machines learn using neural networks

In this definition, DL is a subset of ML, and ML is a subset of AI. Some common examples of models and applications of these include:

  • AI: Minimax, A*, Deep Blue, video game AI

  • ML: OLS, SVM, \(k\)-means, spam filtering

  • DL: MLP, CNN, ChatGPT, self-driving cars

In this tutorial, we will specifically focus on deep learning, but many of the same concepts are shared with machine learning.

Datasets

In order to learn by example, we first need examples. In machine learning, we construct datasets of the form:

\[D = \left\{\left(x^{(i)}, y^{(i)}\right)\right\}_{i=1}^N\]

In words, dataset \(D\) is composed of \(N\) pairs of inputs \(x^{(i)}\) and expected outputs \(y^{(i)}\). Both \(x\) and \(y\) can be tabular data, images, text, or any other object that can be represented mathematically.
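To make the definition concrete, here is a minimal plain-Python sketch of a dataset \(D\) of \(N\) pairs. It follows the map-style protocol that PyTorch datasets implement (__len__ returns \(N\), __getitem__ returns the \(i\)-th pair). A real PyTorch dataset would subclass torch.utils.data.Dataset; we skip that here so the sketch runs without PyTorch. The class name ToyDataset and the data are hypothetical; the 'image' and 'label' keys mirror the samples used later in this tutorial.

```python
class ToyDataset:
    """A toy dataset D = {(x_i, y_i)} following PyTorch's map-style protocol."""

    def __init__(self, inputs, targets):
        # Every input x must have exactly one expected output y
        assert len(inputs) == len(targets)
        self.inputs = inputs
        self.targets = targets

    def __len__(self):
        # N, the number of (x, y) pairs
        return len(self.inputs)

    def __getitem__(self, index):
        # Return the i-th pair as a dict, mirroring the 'image'/'label'
        # keys used by the samples later in this tutorial
        return {'image': self.inputs[index], 'label': self.targets[index]}


# Hypothetical data: three 2x2 single-band "images" and their class labels
dataset = ToyDataset(
    inputs=[[[0, 1], [2, 3]], [[4, 5], [6, 7]], [[8, 9], [10, 11]]],
    targets=[0, 1, 0],
)
print(len(dataset))  # 3
print(dataset[1])  # {'image': [[4, 5], [6, 7]], 'label': 1}
```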

EuroSAT

In this tutorial (and many later tutorials), we will use EuroSAT100, a toy dataset composed of 100 images from the EuroSAT dataset. EuroSAT is a popular image classification dataset with multispectral images from the Sentinel-2 satellites. Each image is classified into one of ten categories or “classes”:

  1. Annual Crop

  2. Forest

  3. Herbaceous Vegetation

  4. Highway

  5. Industrial Buildings

  6. Pasture

  7. Permanent Crop

  8. Residential Buildings

  9. River

  10. Sea & Lake

We can load this dataset and visualize the RGB bands of some example \((x, y)\) pairs like so:

[ ]:
root = os.path.join(tempfile.gettempdir(), 'eurosat100')
dataset = EuroSAT100(root, download=True)

for i in torch.randint(len(dataset), (10,)):
    sample = dataset[i]
    dataset.plot(sample)

In machine learning, we not only want to train a model, but also evaluate its performance on unseen data. Oftentimes, our dataset is split into three separate subsets:

  • train: for training the model parameters

  • val: for validating the model hyperparameters

  • test: for testing the model performance

Parameters are the actual model weights, while hyperparameters are things like model width or learning rate that are chosen by the user. We can initialize datasets for all three splits like so:
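EuroSAT100 ships with predefined splits selected through the split argument. If a dataset does not provide splits, one common approach is a random partition of the sample indices. The sketch below assumes hypothetical 60/20/20 fractions and a fixed seed for reproducibility; the helper split_indices is illustrative, not part of TorchGeo or PyTorch.

```python
import random


def split_indices(n, fractions=(0.6, 0.2, 0.2), seed=0):
    """Randomly partition range(n) into train/val/test index lists."""
    assert abs(sum(fractions) - 1.0) < 1e-9
    indices = list(range(n))
    random.Random(seed).shuffle(indices)  # fixed seed for reproducibility
    n_train = round(fractions[0] * n)
    n_val = round(fractions[1] * n)
    train = indices[:n_train]
    val = indices[n_train:n_train + n_val]
    test = indices[n_train + n_val:]  # the remainder goes to test
    return train, val, test


train_idx, val_idx, test_idx = split_indices(100)
print(len(train_idx), len(val_idx), len(test_idx))  # 60 20 20
```

Because the three index lists are disjoint, no test image can leak into training, which is what makes the test accuracy an estimate of performance on unseen data.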

[ ]:
train_dataset = EuroSAT100(root, split='train')
val_dataset = EuroSAT100(root, split='val')
test_dataset = EuroSAT100(root, split='test')

Data Loaders

While our dataset objects know how to load a single \((x, y)\) pair, machine learning often operates on what are called mini-batches of data. We can pass our above datasets to a PyTorch DataLoader object to construct these mini-batches:

[ ]:
batch_size = 10

train_dataloader = DataLoader(train_dataset, batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size, shuffle=False)
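Conceptually, a data loader walks over the dataset indices (optionally shuffled) and groups them into mini-batches. The plain-Python sketch below shows only that idea: the real DataLoader also collates samples into batched tensors (for example stacking each 'image' into a single tensor) and can load data in parallel worker processes. The generator iterate_minibatches is a hypothetical helper.

```python
import random


def iterate_minibatches(dataset, batch_size, shuffle=False, seed=None):
    """Yield lists of samples, a simplified stand-in for DataLoader."""
    indices = list(range(len(dataset)))
    if shuffle:
        # Shuffle the sample order (a real DataLoader draws a new order each epoch)
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        # The last batch is smaller if len(dataset) % batch_size != 0
        yield [dataset[i] for i in indices[start:start + batch_size]]


data = list(range(25))  # hypothetical dataset of 25 samples
batches = list(iterate_minibatches(data, batch_size=10, shuffle=True, seed=0))
print([len(b) for b in batches])  # [10, 10, 5]
```

Shuffling matters for training because it keeps consecutive mini-batches from always containing the same samples; for validation and testing the order does not affect the result, hence shuffle=False above.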

Transforms

There are two categories of transforms a user may want to apply to their data:

  • Preprocessing: required to make data “ML-ready”

  • Data augmentation: designed to artificially inflate the size of the dataset

Preprocessing transforms such as normalization and one-hot encodings are applied to both training and evaluation data. Data augmentation transforms such as random flip and rotation are typically only performed during training. Below, we initialize transforms for both using the Kornia library.

[ ]:
preprocess = K.Normalize(0, 10000)
augment = K.ImageSequential(K.RandomHorizontalFlip(), K.RandomVerticalFlip())
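K.Normalize(0, 10000) computes (x - mean) / std per pixel, so with mean 0 and std 10000 it simply divides every value by 10000. The random flips mirror the image with some probability (0.5 by default in Kornia). The plain-Python sketch below reproduces these operations on a single-band image stored as nested lists; the function names and the pixel values are hypothetical, and the real Kornia transforms operate on batched tensors.

```python
import random


def normalize(image, mean=0.0, std=10000.0):
    """Per-pixel (x - mean) / std, what K.Normalize(0, 10000) applies."""
    return [[(value - mean) / std for value in row] for row in image]


def random_horizontal_flip(image, p=0.5, rng=random):
    """Mirror each row left-to-right with probability p."""
    if rng.random() < p:
        return [row[::-1] for row in image]
    return image


def random_vertical_flip(image, p=0.5, rng=random):
    """Reverse the order of the rows with probability p."""
    if rng.random() < p:
        return image[::-1]
    return image


# Hypothetical 2x3 single-band image with values in the 0-10000 range
image = [[0, 5000, 10000], [2500, 7500, 10000]]
x = normalize(image)  # preprocessing: applied to every split
print(x)  # [[0.0, 0.5, 1.0], [0.25, 0.75, 1.0]]
x = random_vertical_flip(random_horizontal_flip(x))  # augmentation: training only
```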

Model

Our goal is to learn some function \(f\) that can map between input \(x\) and expected output \(y\). Mathematically, this can be expressed as:

\[x \overset{f}{\mapsto} y, \quad y = f(x)\]

Since our \(x\) in this case is an image, we choose to use ResNet-18, a popular convolutional neural network (CNN). We also initialize our model with weights that have been pre-trained on Sentinel-2 imagery so we don’t have to start from scratch. This process is known as transfer learning.

[ ]:
# num_classes=10 gives one output (logit) per EuroSAT class
model = resnet18(ResNet18_Weights.SENTINEL2_ALL_MOCO, num_classes=10)

Loss Function

If \(y\) is our expected output (also called “ground truth”) and \(\hat{y}\) is our predicted output, our goal is to minimize the difference between \(y\) and \(\hat{y}\). This difference is referred to as error or loss, and the loss function measures how large a mistake we made. For regression tasks, the squared error, averaged over all samples to give the mean squared error (MSE), is often sufficient:

\[L(y, \hat{y}) = \left(y - \hat{y}\right)^2\]
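The formula gives the error for a single pair; in practice it is averaged over the samples in a mini-batch, which is what PyTorch's nn.MSELoss does with its default reduction='mean'. A plain-Python sketch with hypothetical values:

```python
def mean_squared_error(y, y_hat):
    """Average of the squared errors (y - y_hat)^2 over a batch."""
    assert len(y) == len(y_hat)
    return sum((a - b) ** 2 for a, b in zip(y, y_hat)) / len(y)


# Hypothetical regression targets and predictions: only the last one is off by 2
print(mean_squared_error([1.0, 2.0, 3.0], [1.0, 2.0, 5.0]))  # 1.3333333333333333
```

Squaring makes large errors dominate: a single prediction that is off by 2 contributes four times as much as one that is off by 1.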

For classification tasks, such as EuroSAT, we instead use a negative log-likelihood:

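\[L(y, \hat{y}) = - \sum_{c=1}^C \mathbb{1}_{y=c}\log{p_c}\]

where \(C\) is the number of classes, \(\mathbb{1}_{y=c}\) is the indicator function (1 if the true class \(y\) is \(c\), 0 otherwise), and \(p_c\) is the probability with which the model predicts class \(c\). Since only the term for the true class survives, the loss reduces to \(-\log{p_y}\). When the probabilities \(p_c\) are obtained by applying a softmax to the model’s raw outputs (logits), this is known as the cross-entropy loss. PyTorch’s CrossEntropyLoss combines the softmax and the negative log-likelihood in a single, numerically stable operation, so the model should output raw logits rather than probabilities.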
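The plain-Python sketch below evaluates the negative log-likelihood above for a single sample, starting from raw logits. It computes log-probabilities with the log-sum-exp trick (subtracting the largest logit before exponentiating) so that large logits do not overflow; this is a standard technique, shown here as an illustration rather than as PyTorch’s internal implementation. The logits are hypothetical.

```python
import math


def log_softmax(logits):
    """log p_c for each class c, computed stably via log-sum-exp."""
    m = max(logits)  # subtract the max so exp() cannot overflow
    log_total = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_total for z in logits]


def cross_entropy(logits, y):
    """-sum_c 1[y = c] log p_c for a single sample with true class index y."""
    log_p = log_softmax(logits)
    # The indicator 1[y = c] zeroes every term except the one where c == y
    return -sum(log_p_c for c, log_p_c in enumerate(log_p) if c == y)


# Hypothetical logits for a 3-class problem where the true class is 0
print(cross_entropy([2.0, 1.0, 0.1], y=0))
# Two equally likely classes: the loss is log(2), about 0.693
print(cross_entropy([0.0, 0.0], y=1))
```

A confident, correct prediction yields a loss near 0, while a confident, wrong prediction yields a large loss; nn.CrossEntropyLoss additionally averages this quantity over the mini-batch by default.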

[ ]:
loss_fn = nn.CrossEntropyLoss()

Optimizer

In order to minimize our loss, we compute the gradient of the loss function with respect to the model parameters \(\theta\); in neural networks, this gradient is computed efficiently by an algorithm called backpropagation. We then take a small step of size \(\alpha\) (also called the learning rate) in the direction of the negative gradient to update our model parameters, a process called gradient descent:

\[\theta \leftarrow \theta - \alpha \nabla_\theta L(y, \hat{y})\]

When the gradient is computed from a single image or mini-batch at a time, rather than from the entire dataset, this is known as stochastic gradient descent (SGD).
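To see the update rule at work, the plain-Python sketch below fits a single parameter \(\theta\) in the toy model \(\hat{y} = \theta x\) with mini-batch SGD on the squared error. PyTorch computes gradients automatically (loss.backward()); here we derive it by hand, \(\frac{\partial}{\partial\theta}(y - \theta x)^2 = -2x(y - \theta x)\). The function sgd_fit, the data, and the hyperparameters are hypothetical choices for illustration.

```python
import random


def sgd_fit(data, lr=0.05, epochs=100, batch_size=2, seed=0):
    """Fit y_hat = theta * x by mini-batch SGD on the squared error."""
    rng = random.Random(seed)
    theta = 0.0  # initial parameter
    data = list(data)  # copy so shuffling does not modify the caller's list
    for _ in range(epochs):
        rng.shuffle(data)
        for start in range(0, len(data), batch_size):
            batch = data[start:start + batch_size]
            # Gradient of the mini-batch mean of (y - theta * x)^2 w.r.t. theta
            grad = sum(-2 * x * (y - theta * x) for x, y in batch) / len(batch)
            theta -= lr * grad  # theta <- theta - alpha * gradient
    return theta


# Hypothetical noise-free data generated by y = 3x
data = [(x, 3.0 * x) for x in [0.5, 1.0, 1.5, 2.0]]
print(round(sgd_fit(data), 3))  # 3.0
```

If the learning rate is too large, each step overshoots the minimum and \(\theta\) can diverge; if it is too small, training converges slowly. This trade-off is why the learning rate is treated as a hyperparameter.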

[ ]:
optimizer = optim.SGD(model.parameters(), lr=1e-2)

Device

If you peek into the internals of deep learning models, you’ll notice that most of the computation is linear algebra, such as matrix multiplications. These operations are highly parallelizable and therefore run very quickly on a GPU. We now move our model to the GPU (if one is available); each mini-batch of data is moved to the same device inside the training and evaluation loops below.

[ ]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = model.to(device)

Training

We finally have all the basic components we need to train our ResNet-18 model on the EuroSAT100 dataset. During training, we set the model to train mode, then iterate over all mini-batches in the dataset. During the forward pass, we ask the model \(f\) to predict \(\hat{y}\) given \(x\), and then calculate the loss incurred by these predictions. During the backward pass, we backpropagate the loss to compute gradients, which the optimizer uses to update all model weights. We then reset the gradients to zero so that they do not accumulate into the next mini-batch.

[ ]:
def train(dataloader):
    model.train()
    total_loss = 0
    for batch in dataloader:
        x = batch['image'].to(device)
        y = batch['label'].to(device)
        x = preprocess(x)
        x = augment(x)

        # Forward pass
        y_hat = model(x)
        loss = loss_fn(y_hat, y)
        total_loss += loss.item()

        # Backward pass
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    print(f'Loss: {total_loss:.2f}')

Evaluation

Once the model is trained, we need to evaluate its performance on unseen data. To do this, we set the model to evaluation mode, then iterate over all mini-batches in the dataset. Note that we also disable gradient computation, since we do not need to backpropagate during evaluation. Finally, we compute the fraction of correctly classified images, also known as accuracy.
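The predicted class for each image is the index of its largest logit (the argmax), and accuracy is the fraction of predictions that match the labels. A plain-Python sketch of that computation, using hypothetical logits for four samples and three classes:

```python
def accuracy(logits_batch, labels):
    """Fraction of samples whose highest-scoring class matches the label."""
    assert len(logits_batch) == len(labels)
    correct = 0
    for logits, y in zip(logits_batch, labels):
        # argmax over classes, analogous to y_hat.argmax(1) in evaluate()
        predicted = max(range(len(logits)), key=lambda c: logits[c])
        correct += int(predicted == y)
    return correct / len(labels)


# Hypothetical logits: predicted classes are 0, 1, 2, 0
logits_batch = [[2.0, 0.1, 0.3], [0.2, 1.5, 0.1], [0.0, 0.2, 3.0], [1.0, 0.9, 0.8]]
labels = [0, 1, 1, 0]
print(f'Accuracy: {accuracy(logits_batch, labels):.0%}')  # Accuracy: 75%
```

Because the softmax is monotonic, taking the argmax of the raw logits gives the same prediction as taking the argmax of the probabilities, so no softmax is needed at evaluation time.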

[ ]:
def evaluate(dataloader):
    model.eval()
    correct = 0
    with torch.no_grad():
        for batch in dataloader:
            x = batch['image'].to(device)
            y = batch['label'].to(device)
            x = preprocess(x)

            # Forward pass
            y_hat = model(x)
            correct += (y_hat.argmax(1) == y).type(torch.float).sum().item()

    correct /= len(dataloader.dataset)
    print(f'Accuracy: {correct:.0%}')

Putting It All Together

In machine learning, we typically iterate over our datasets multiple times. Each full pass through the dataset is called an epoch. The following hyperparameter controls the number of epochs for which we train our model, and can be increased to train the model for longer:

[ ]:
epochs = 100

During each epoch, we train the model on our training dataset, then evaluate its performance on the validation dataset. The goal is for training loss to decrease and validation accuracy to increase, although you should expect noise in the training process. Generally, you want to train the model until the validation accuracy starts to plateau or even decrease.
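Stopping once validation accuracy plateaus can be automated with a technique called early stopping, which the loop below does not implement. The plain-Python sketch that follows assumes you have recorded one validation accuracy per epoch (for example by having evaluate() return its accuracy instead of only printing it); the helper best_epoch, the patience value, and the accuracy history are hypothetical.

```python
def best_epoch(val_accuracies, patience=3):
    """Return (best_epoch, stop_epoch), stopping once accuracy has not
    improved for `patience` consecutive epochs."""
    best_acc = float('-inf')
    best = 0
    for epoch, acc in enumerate(val_accuracies):
        if acc > best_acc:
            best_acc, best = acc, epoch
        elif epoch - best >= patience:
            return best, epoch  # plateaued: stop training here
    return best, len(val_accuracies) - 1  # never plateaued


# Hypothetical validation accuracies that rise and then plateau
history = [0.40, 0.55, 0.62, 0.70, 0.69, 0.70, 0.68, 0.71]
print(best_epoch(history))  # (3, 6)
```

Note that the history improves again at the last epoch: a small patience reacts quickly but can stop on noise, while a larger patience tolerates noise at the cost of extra epochs.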

[ ]:
for epoch in range(epochs):
    print(f'Epoch: {epoch}')
    train(train_dataloader)
    evaluate(val_dataloader)

Finally, we evaluate our performance on the test dataset. Note that we are only training our model on a toy dataset consisting of 100 images. If we instead trained on the full dataset (replace EuroSAT100 with EuroSAT in the above code), we would likely get much higher performance.

[ ]:
evaluate(test_dataloader)

Additional Reading

If you are new to machine learning and overwhelmed by all of the above terminology, or would like to gain a better understanding of the math behind machine learning, I would highly recommend a formal machine learning or deep learning course. The official PyTorch tutorials are also worth exploring.
