[ ]:
# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.
Pretrained Weights
Written by: Nils Lehmann
In this tutorial, we demonstrate some of the pretrained weights available in TorchGeo. The implementation follows torchvision’s recently introduced Multi-Weight API.
It’s recommended to run this notebook on Google Colab if you don’t have your own GPU. Click the “Open in Colab” button above to get started.
Setup
First, we install TorchGeo.
[ ]:
%pip install torchgeo
Imports
Next, we import TorchGeo and timm.
[ ]:
%matplotlib inline
import timm
from torchgeo.models import ResNet18_Weights, ResNet50_Weights, resnet18
Pretrained Weights
Pretrained weights for torchgeo.models are available and sorted by satellite or sensor type: sensor-agnostic, Landsat, NAIP, Sentinel-1, and Sentinel-2. Refer to the model documentation for a complete list of weights, and choose among them based on your specific use case.
While some weights accept only RGB input, others have been pretrained on all 13 Sentinel-2 bands and can therefore be useful for transfer learning on Sentinel-2 data.
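For example, if your imagery contains all 13 Sentinel-2 bands but you want to try RGB-only weights such as ResNet50_Weights.SENTINEL2_RGB_MOCO, you first need to pick out the red, green, and blue bands. The sketch below assumes the bands are stacked in the standard Sentinel-2 order (B01 through B12, including B8A) and that the RGB weights expect red, green, blue order (B04, B03, B02); check your dataset's band order and the weight documentation before relying on either assumption. The helper `select_bands` is hypothetical and the array is random stand-in data, not real imagery.

```python
import numpy as np

# Assumed stacking order of the 13 Sentinel-2 bands in the input array.
S2_BANDS = ['B01', 'B02', 'B03', 'B04', 'B05', 'B06', 'B07',
            'B08', 'B8A', 'B09', 'B10', 'B11', 'B12']

# Assumed channel order expected by RGB-only weights: red, green, blue.
RGB_BANDS = ['B04', 'B03', 'B02']


def select_bands(image: np.ndarray, wanted: list[str],
                 available: list[str] = S2_BANDS) -> np.ndarray:
    """Hypothetical helper: return the (C, H, W) channels named in `wanted`, in that order."""
    if image.shape[0] != len(available):
        raise ValueError(f'expected {len(available)} channels, got {image.shape[0]}')
    indices = [available.index(band) for band in wanted]
    return image[indices]


# Random stand-in for a 13-band Sentinel-2 patch of 64 x 64 pixels.
rng = np.random.default_rng(0)
patch = rng.random((13, 64, 64), dtype=np.float32)

rgb_patch = select_bands(patch, RGB_BANDS)
print(patch.shape, '->', rgb_patch.shape)
```

Going the other way (feeding a 13-band patch to weights that expect 3 channels) would fail at the model's first convolution, which is why the number of input channels matters when choosing weights.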
To use these weights, first select them as follows:
[ ]:
all_weights = ResNet18_Weights.SENTINEL2_ALL_MOCO
rgb_weights = ResNet50_Weights.SENTINEL2_RGB_MOCO
Weight Metadata
Each of these weight objects is a member of a torchvision WeightsEnum and holds information such as the download URL and additional metadata. TorchGeo takes care of downloading the weights and initializing models with them.
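To make the enum-of-weights pattern concrete, the sketch below mimics it using only the standard library: each enum member wraps a small record holding a download URL and a metadata dictionary. The names `Weights` and `ToyResNetWeights`, the URLs, and the metadata values are all hypothetical stand-ins, not torchvision's or TorchGeo's actual implementation or values.

```python
from dataclasses import dataclass, field
from enum import Enum
from typing import Any


@dataclass
class Weights:
    """Hypothetical record: where a checkpoint lives and what it expects."""

    url: str
    meta: dict[str, Any] = field(default_factory=dict)


class ToyResNetWeights(Enum):
    """Hypothetical weight enum with one member per pretrained checkpoint."""

    SENTINEL2_ALL = Weights(
        url='https://example.com/toy_resnet_sentinel2_all.pth',
        meta={'in_chans': 13, 'sensor': 'sentinel2'},
    )
    SENTINEL2_RGB = Weights(
        url='https://example.com/toy_resnet_sentinel2_rgb.pth',
        meta={'in_chans': 3, 'sensor': 'sentinel2'},
    )

    # Convenience accessors so callers can write member.url / member.meta.
    @property
    def url(self) -> str:
        return self.value.url

    @property
    def meta(self) -> dict[str, Any]:
        return self.value.meta


# Members can be referenced directly or looked up by name, e.g. from a config file.
weights = ToyResNetWeights['SENTINEL2_ALL']
print(weights.name, weights.url, weights.meta['in_chans'])

# Iterating over the enum lists every available checkpoint and its metadata.
for member in ToyResNetWeights:
    print(member.name, member.meta)
```

The real torchvision enum members additionally provide a get_state_dict method that downloads the checkpoint, which the timm example later in this notebook calls.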
Let’s inspect the metadata of the two sets of weights we have just selected:
[ ]:
# ResNet18_Weights.SENTINEL2_ALL_MOCO
print(f'Weight URL: {all_weights.url}')
print('Weight metadata:')
for key, value in all_weights.meta.items():
    print(f' {key}: {value}')
[ ]:
# ResNet50_Weights.SENTINEL2_RGB_MOCO
print(f'Weight URL: {rgb_weights.url}')
print('Weight metadata:')
for key, value in rgb_weights.meta.items():
    print(f' {key}: {value}')
Using Pretrained Weights for Training
We can load the pretrained weights ResNet18_Weights.SENTINEL2_ALL_MOCO into a ResNet-18 model like this:
[ ]:
model = resnet18(all_weights)
Here, TorchGeo simply acts as a wrapper around timm. If you don’t want to use this wrapper, you can create a timm model directly and load the pretrained weights from TorchGeo as follows:
[ ]:
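# Match the model's first convolution to the number of input channels the weights expect
in_chans = all_weights.meta['in_chans']
model = timm.create_model('resnet18', in_chans=in_chans, num_classes=10)
# strict=False tolerates keys missing from the checkpoint, such as the new 10-class head
model.load_state_dict(all_weights.get_state_dict(progress=True), strict=False)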
To train our pretrained model on a dataset, we will use Lightning’s Trainer. For a more detailed explanation of how TorchGeo uses Lightning, check out the next tutorial.