# Source code for torchgeo.models.unet
# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.
"""Pre-trained U-Net models."""
from typing import Any
import kornia.augmentation as K
import segmentation_models_pytorch as smp
import torch
from segmentation_models_pytorch import Unet
from torchvision.models._api import Weights, WeightsEnum
# Band order follows https://github.com/fieldsoftheworld/ftw-baselines:
# the first four Sentinel-2 bands belong to image t1, the last four to image t2.
_ftw_sentinel2_bands = ['B4', 'B3', 'B2', 'B8A'] * 2

# Normalization by 3k (for S2 uint16 input), see
# https://github.com/fieldsoftheworld/ftw-baselines/blob/main/src/ftw/datamodules.py
_ftw_transforms = K.AugmentationSequential(
    K.Normalize(
        mean=torch.tensor(0.0),
        std=torch.tensor(3000.0),
    ),
    data_keys=None,
)

# No normalization is used for TCD inputs (only a resize to 1024 x 1024); see
# https://github.com/Restor-Foundation/tcd/blob/main/src/tcd_pipeline/data/datamodule.py#L145
_tcd_bands = ['R', 'G', 'B']
_tcd_transforms = K.AugmentationSequential(
    K.Resize(size=(1024, 1024)),
    data_keys=None,
)
def _ftw_meta(num_classes: int, license_name: str) -> dict[str, Any]:
    """Build the metadata shared by the FTW Sentinel-2 U-Net checkpoints.

    Args:
        num_classes: Number of output classes of the checkpoint.
        license_name: License string to record for the checkpoint.

    Returns:
        A new metadata dictionary for one FTW checkpoint.
    """
    return {
        'dataset': 'FTW',
        'in_chans': 8,
        'num_classes': num_classes,
        'model': 'U-Net',
        'encoder': 'efficientnet-b3',
        'publication': 'https://arxiv.org/abs/2409.16252',
        'repo': 'https://github.com/fieldsoftheworld/ftw-baselines',
        'bands': _ftw_sentinel2_bands,
        'license': license_name,
    }


def _tcd_meta(encoder: str) -> dict[str, Any]:
    """Build the metadata shared by the OAM-TCD U-Net checkpoints.

    Args:
        encoder: Name of the encoder backbone of the checkpoint.

    Returns:
        A new metadata dictionary for one OAM-TCD checkpoint.
    """
    return {
        'dataset': 'OAM-TCD',
        'in_chans': 3,
        'num_classes': 2,
        'model': 'U-Net',
        'encoder': encoder,
        'publication': 'https://arxiv.org/abs/2407.11743',
        'repo': 'https://github.com/restor-foundation/tcd',
        'bands': _tcd_bands,
        'classes': ('background', 'tree-canopy'),
        'input_shape': (3, 1024, 1024),
        'resolution': 0.1,
        'license': 'CC-BY-NC-4.0',
    }


class Unet_Weights(WeightsEnum):  # type: ignore[misc]
    """U-Net weights.

    For `smp <https://github.com/qubvel-org/segmentation_models.pytorch>`_
    *Unet* implementation.

    .. versionadded:: 0.8
    """

    # FTW checkpoints: 8-channel Sentinel-2 input, EfficientNet-B3 encoder.
    # The "NC" variants are the ones recorded with a non-commercial license.
    SENTINEL2_2CLASS_FTW = Weights(
        url='https://huggingface.co/torchgeo/ftw/resolve/d2fdab6ea9d9cd38b491292cc9a5c8642533cef5/commercial/2-class/sentinel2_unet_effb3-9c04b7c6.pth',
        transforms=_ftw_transforms,
        meta=_ftw_meta(num_classes=2, license_name='CC-BY-4.0'),
    )
    SENTINEL2_3CLASS_FTW = Weights(
        url='https://huggingface.co/torchgeo/ftw/resolve/d2fdab6ea9d9cd38b491292cc9a5c8642533cef5/commercial/3-class/sentinel2_unet_effb3-5d591cbb.pth',
        transforms=_ftw_transforms,
        meta=_ftw_meta(num_classes=3, license_name='CC-BY-4.0'),
    )
    SENTINEL2_2CLASS_NC_FTW = Weights(
        url='https://huggingface.co/torchgeo/ftw/resolve/d2fdab6ea9d9cd38b491292cc9a5c8642533cef5/noncommercial/2-class/sentinel2_unet_effb3-bf010a31.pth',
        transforms=_ftw_transforms,
        meta=_ftw_meta(num_classes=2, license_name='non-commercial'),
    )
    SENTINEL2_3CLASS_NC_FTW = Weights(
        url='https://huggingface.co/torchgeo/ftw/resolve/d2fdab6ea9d9cd38b491292cc9a5c8642533cef5/noncommercial/3-class/sentinel2_unet_effb3-ed36f465.pth',
        transforms=_ftw_transforms,
        meta=_ftw_meta(num_classes=3, license_name='non-commercial'),
    )

    # OAM-TCD checkpoints: 3-channel RGB input, two classes
    # (background / tree-canopy).
    OAM_RGB_RESNET50_TCD = Weights(
        url='https://hf.co/isaaccorley/unet_resnet50_oam_rgb_tcd/resolve/5df2fe5a0e80fd6e12939686b7370c53f73bf389/unet_resnet50_oam_rgb_tcd-72b9b753.pth',
        transforms=_tcd_transforms,
        meta=_tcd_meta(encoder='resnet50'),
    )
    # NOTE(review): this filename carries the same hash suffix (72b9b753) as
    # the ResNet-50 checkpoint above -- confirm it is the intended file.
    OAM_RGB_RESNET34_TCD = Weights(
        url='https://hf.co/isaaccorley/unet_resnet34_oam_rgb_tcd/resolve/40c914bbcbe43a6a87c81adb0a22ff2d4a53204d/unet_resnet34_oam_rgb_tcd-72b9b753.pth',
        transforms=_tcd_transforms,
        meta=_tcd_meta(encoder='resnet34'),
    )
def unet(
    weights: Unet_Weights | None = None,
    classes: int | None = None,
    *args: Any,
    **kwargs: Any,
) -> Unet:
    """U-Net model.

    If you use this model in your research, please cite the following paper:

    * https://arxiv.org/abs/1505.04597

    When *weights* are given, the encoder name and the number of input
    channels are taken from the weight metadata (overriding any values passed
    in *kwargs*) and ``encoder_weights`` is set to ``None``. If the requested
    number of classes differs from the one stored in the weight metadata, the
    segmentation head entries are removed from the checkpoint before loading,
    so the head keeps its freshly initialized parameters.

    .. versionadded:: 0.8

    Args:
        weights: Pre-trained model weights to use.
        classes: Number of output classes. If not specified, the number of
            classes will be inferred from the weights, or set to 1 when no
            weights are given.
        *args: Additional arguments to pass to
            ``segmentation_models_pytorch.create_model``
        **kwargs: Additional keyword arguments to pass to
            ``segmentation_models_pytorch.create_model``

    Returns:
        A U-Net model.

    Raises:
        AssertionError: If loading the weights reports unexpected keys, or
            missing keys other than those of the segmentation head.
    """
    head_keys = ('segmentation_head.0.weight', 'segmentation_head.0.bias')

    kwargs['arch'] = 'Unet'
    if weights:
        # The architecture-defining settings must match the checkpoint.
        kwargs.update(
            encoder_weights=None,
            in_channels=weights.meta['in_chans'],
            encoder_name=weights.meta['encoder'],
        )
        if classes is None:
            classes = weights.meta['num_classes']
    elif classes is None:
        classes = 1
    kwargs['classes'] = classes

    model: Unet = smp.create_model(*args, **kwargs)
    if not weights:
        return model

    state_dict = weights.get_state_dict(progress=True)

    if kwargs['classes'] == weights.meta['num_classes']:
        # Same number of classes: load the full pretrained model.
        strict = True
    else:
        # New task: drop the pretrained head so it stays randomly initialized.
        for key in head_keys:
            del state_dict[key]
        strict = False

    missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=strict)

    # Only the segmentation head may be absent from the loaded weights.
    assert set(missing_keys) <= set(head_keys)
    assert not unexpected_keys

    return model