{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "id": "p63J-QmUrMN-" }, "outputs": [], "source": [ "# Copyright (c) TorchGeo Contributors. All rights reserved.\n", "# Licensed under the MIT License." ] }, { "cell_type": "markdown", "metadata": { "id": "XRSkMFqyrMOE" }, "source": [ "# Pretrained Weights\n", "\n", "_Written by: Nils Lehmann_\n", "\n", "In this tutorial, we demonstrate how to use some of the pretrained weights available in TorchGeo. The implementation follows torchvision's [Multi-Weight API](https://pytorch.org/blog/introducing-torchvision-new-multi-weight-support-api/).\n", "\n", "It's recommended to run this notebook on Google Colab if you don't have your own GPU. Click the \"Open in Colab\" button above to get started." ] }, { "cell_type": "markdown", "metadata": { "id": "NBa5RPAirMOF" }, "source": [ "## Setup\n", "\n", "First, we install TorchGeo." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5AIQ1B9DrMOG", "outputId": "6bf360ea-8f60-45cf-c96e-0eac54818079" }, "outputs": [], "source": [ "%pip install torchgeo" ] }, { "cell_type": "markdown", "metadata": { "id": "IcCOnzVLrMOI" }, "source": [ "## Imports\n", "\n", "Next, we import TorchGeo and timm." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rjEGiiurrMOI" }, "outputs": [], "source": [ "%matplotlib inline\n", "\n", "import timm\n", "\n", "from torchgeo.models import ResNet18_Weights, ResNet50_Weights, resnet18" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Pretrained Weights\n", "\n", "Pretrained weights for `torchgeo.models` are available, grouped by satellite or sensor type: sensor-agnostic, Landsat, NAIP, Sentinel-1, and Sentinel-2. Refer to the [model documentation](https://torchgeo.readthedocs.io/en/stable/api/models.html#pretrained-weights) for a complete list of weights. 
Choose the pretrained weights that best match your use case.\n", "\n", "Some weights only accept RGB input, while others have been pretrained on Sentinel-2 imagery with 13 input channels and are therefore useful for transfer learning tasks involving Sentinel-2 data.\n", "\n", "To use these weights, select the corresponding enum member:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "RZ8MPYH1rMON", "outputId": "fa683b8f-da21-4f26-ca3a-46163c9f12bf" }, "outputs": [], "source": [ "all_weights = ResNet18_Weights.SENTINEL2_ALL_MOCO\n", "\n", "rgb_weights = ResNet50_Weights.SENTINEL2_RGB_MOCO" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Weight Metadata\n", "\n", "Each set of weights is a member of a torchvision `WeightsEnum` and holds information such as the download URL and additional metadata. TorchGeo takes care of downloading the weights and initializing a model with them.\n", "\n", "Let's inspect the metadata of the two sets of pretrained weights we have just selected:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# ResNet18_Weights.SENTINEL2_ALL_MOCO\n", "print(f'Weight URL: {all_weights.url}')\n", "\n", "print('Weight metadata:')\n", "for key, value in all_weights.meta.items():\n", "    print(f'  {key}: {value}')" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# ResNet50_Weights.SENTINEL2_RGB_MOCO\n", "print(f'Weight URL: {rgb_weights.url}')\n", "\n", "print('Weight metadata:')\n", "for key, value in rgb_weights.meta.items():\n", "    print(f'  {key}: {value}')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Using Pretrained Weights for Training" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can load the pretrained weights `ResNet18_Weights.SENTINEL2_ALL_MOCO` into a ResNet-18 model like this:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, 
"outputs": [], "source": [ "model = resnet18(all_weights)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here, TorchGeo simply acts as a wrapper around [timm](https://github.com/huggingface/pytorch-image-models). If you don't want to use this wrapper, you can create a timm model directly and load the pretrained weights from TorchGeo as follows:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "in_chans = all_weights.meta['in_chans']\n", "model = timm.create_model('resnet18', in_chans=in_chans, num_classes=10)\n", "model.load_state_dict(all_weights.get_state_dict(progress=True), strict=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To train our pretrained model on a dataset we will make use of Lightning's [Trainer](https://lightning.ai/docs/pytorch/stable/common/trainer.html). For a more elaborate explanation of how TorchGeo uses Lightning, check out [this next tutorial](https://torchgeo.readthedocs.io/en/stable/tutorials/trainers.html)." ] } ], "metadata": { "accelerator": "GPU", "execution": { "timeout": 1200 }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.13.0" }, "vscode": { "interpreter": { "hash": "b058dd71d0e7047e70e62f655d92ec955f772479bbe5e5addd202027292e8f60" } } }, "nbformat": 4, "nbformat_minor": 4 }