# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.
"""Pre-trained Vision Transformer models."""
from typing import Any, cast
import kornia.augmentation as K
import timm
import torch
from torch import nn
from torchvision.models._api import Weights, WeightsEnum
from .resnet import (
_landsat_etm_sr_bands,
_landsat_etm_toa_bands,
_landsat_oli_sr_bands,
_landsat_oli_tirs_toa_bands,
_landsat_tm_toa_bands,
_sentinel1_grd_bands,
_sentinel2_toa_bands,
_ssl4eo_s12_transforms_s1,
_ssl4eo_s12_transforms_s2_stats,
)
# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/linear_BE_moco.py#L167
# https://github.com/zhu-xlab/SSL4EO-S12/blob/d2868adfada65e40910bfcedfc49bc3b20df2248/src/benchmark/transfer_classification/datasets/EuroSat/eurosat_dataset.py#L97
# Normalization either by 10K or channel-wise with band statistics
_zhu_xlab_transforms = K.AugmentationSequential(
K.Resize((256, 256)),
K.CenterCrop(224),
K.Normalize(mean=torch.tensor(0), std=torch.tensor(10000)),
data_keys=None,
)
# https://github.com/torchgeo/torchgeo/blob/8b53304d42c269f9001cb4e861a126dc4b462606/torchgeo/datamodules/ssl4eo_benchmark.py#L43
_ssl4eo_l_transforms = K.AugmentationSequential(
K.Normalize(mean=torch.tensor(0), std=torch.tensor(255)),
K.CenterCrop((224, 224)),
data_keys=None,
)
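
# A minimal sketch (illustrative, not part of the module) of how these Kornia
# pipelines are applied: with data_keys=None, AugmentationSequential consumes
# and returns a dict of tensors, e.g. for a hypothetical Landsat batch:
#
#   sample = {'image': torch.rand(2, 7, 264, 264) * 255}
#   out = _ssl4eo_l_transforms(sample)
#   assert out['image'].shape == (2, 7, 224, 224)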
KEYS = {'norm.weight', 'norm.bias', 'head.weight', 'head.bias'}
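# These are the final-norm and classification-head parameters. The SSL
# checkpoints below are backbone-only, so loading with strict=False is
# expected to miss (or, with features_only=True, not expect) exactly these keys.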


class ViTSmall16_Weights(WeightsEnum):  # type: ignore[misc]
    """Vision Transformer Small Patch Size 16 weights.

    For `timm <https://github.com/huggingface/pytorch-image-models>`_
    *vit_small_patch16_224* implementation.

    .. versionadded:: 0.4
    """

LANDSAT_TM_TOA_MOCO = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_tm_toa_moco-a1c967d8.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
'in_chans': 7,
'model': 'vit_small_patch16_224',
'publication': 'https://arxiv.org/abs/2306.09424',
'repo': 'https://github.com/torchgeo/torchgeo',
'ssl_method': 'moco',
'bands': _landsat_tm_toa_bands,
},
)
LANDSAT_TM_TOA_SIMCLR = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_tm_toa_simclr-7c2d9799.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
'in_chans': 7,
'model': 'vit_small_patch16_224',
'publication': 'https://arxiv.org/abs/2306.09424',
'repo': 'https://github.com/torchgeo/torchgeo',
'ssl_method': 'simclr',
'bands': _landsat_tm_toa_bands,
},
)
LANDSAT_ETM_TOA_MOCO = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_toa_moco-26d19bcf.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
'in_chans': 9,
'model': 'vit_small_patch16_224',
'publication': 'https://arxiv.org/abs/2306.09424',
'repo': 'https://github.com/torchgeo/torchgeo',
'ssl_method': 'moco',
'bands': _landsat_etm_toa_bands,
},
)
LANDSAT_ETM_TOA_SIMCLR = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_toa_simclr-34fb12cb.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
'in_chans': 9,
'model': 'vit_small_patch16_224',
'publication': 'https://arxiv.org/abs/2306.09424',
'repo': 'https://github.com/torchgeo/torchgeo',
'ssl_method': 'simclr',
'bands': _landsat_etm_toa_bands,
},
)
LANDSAT_ETM_SR_MOCO = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_sr_moco-eaa4674e.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
'in_chans': 6,
'model': 'vit_small_patch16_224',
'publication': 'https://arxiv.org/abs/2306.09424',
'repo': 'https://github.com/torchgeo/torchgeo',
'ssl_method': 'moco',
'bands': _landsat_etm_sr_bands,
},
)
LANDSAT_ETM_SR_SIMCLR = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_etm_sr_simclr-a14c466a.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
'in_chans': 6,
'model': 'vit_small_patch16_224',
'publication': 'https://arxiv.org/abs/2306.09424',
'repo': 'https://github.com/torchgeo/torchgeo',
'ssl_method': 'simclr',
'bands': _landsat_etm_sr_bands,
},
)
LANDSAT_OLI_TIRS_TOA_MOCO = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_tirs_toa_moco-c7c2cceb.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
'in_chans': 11,
'model': 'vit_small_patch16_224',
'publication': 'https://arxiv.org/abs/2306.09424',
'repo': 'https://github.com/torchgeo/torchgeo',
'ssl_method': 'moco',
'bands': _landsat_oli_tirs_toa_bands,
},
)
LANDSAT_OLI_TIRS_TOA_SIMCLR = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_tirs_toa_simclr-ad43e9a4.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
'in_chans': 11,
'model': 'vit_small_patch16_224',
'publication': 'https://arxiv.org/abs/2306.09424',
'repo': 'https://github.com/torchgeo/torchgeo',
'ssl_method': 'simclr',
'bands': _landsat_oli_tirs_toa_bands,
},
)
LANDSAT_OLI_SR_MOCO = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_sr_moco-c9b8898d.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
'in_chans': 7,
'model': 'vit_small_patch16_224',
'publication': 'https://arxiv.org/abs/2306.09424',
'repo': 'https://github.com/torchgeo/torchgeo',
'ssl_method': 'moco',
'bands': _landsat_oli_sr_bands,
},
)
LANDSAT_OLI_SR_SIMCLR = Weights(
url='https://hf.co/torchgeo/ssl4eo_landsat/resolve/1c88bb51b6e17a21dde5230738fa38b74bd74f76/vits16_landsat_oli_sr_simclr-4e8f6102.pth',
transforms=_ssl4eo_l_transforms,
meta={
'dataset': 'SSL4EO-L',
'in_chans': 7,
'model': 'vit_small_patch16_224',
'publication': 'https://arxiv.org/abs/2306.09424',
'repo': 'https://github.com/torchgeo/torchgeo',
'ssl_method': 'simclr',
'bands': _landsat_oli_sr_bands,
},
)
SENTINEL2_ALL_CLOSP = Weights(
url='https://huggingface.co/DarthReca/CLOSP-Visual/resolve/3bb8677c21dac56bea2dd7baa08d7871272db440/closp-vs_s2_encoder-1a3ee5a5.pth',
transforms=K.AugmentationSequential(
K.Normalize(mean=0, std=10000), K.Resize(224), data_keys=None
),
meta={
'dataset': 'CrisisLandMark',
'in_chans': 13,
'model': 'vit_small_patch16_224',
'publication': 'https://arxiv.org/abs/2507.10403',
'repo': 'https://github.com/DarthReca/closp',
'bands': _sentinel2_toa_bands,
},
)
SENTINEL2_ALL_DINO = Weights(
url='https://hf.co/torchgeo/vit_small_patch16_224_sentinel2_all_dino/resolve/5b41dd418a79de47ac9f5be3e035405a83818a62/vit_small_patch16_224_sentinel2_all_dino-36bcc127.pth',
transforms=_zhu_xlab_transforms,
meta={
'dataset': 'SSL4EO-S12',
'in_chans': 13,
'model': 'vit_small_patch16_224',
'publication': 'https://arxiv.org/abs/2211.07044',
'repo': 'https://github.com/zhu-xlab/SSL4EO-S12',
'ssl_method': 'dino',
'bands': _sentinel2_toa_bands,
},
)
SENTINEL2_ALL_MOCO = Weights(
url='https://hf.co/torchgeo/vit_small_patch16_224_sentinel2_all_moco/resolve/1cb683f6c14739634cdfaaceb076529adf898c74/vit_small_patch16_224_sentinel2_all_moco-67c9032d.pth',
transforms=_zhu_xlab_transforms,
meta={
'dataset': 'SSL4EO-S12',
'in_chans': 13,
'model': 'vit_small_patch16_224',
'publication': 'https://arxiv.org/abs/2211.07044',
'repo': 'https://github.com/zhu-xlab/SSL4EO-S12',
'ssl_method': 'moco',
'bands': _sentinel2_toa_bands,
},
)
SENTINEL2_ALL_MAE = Weights(
url='https://huggingface.co/wangyi111/SSL4EO-S12/resolve/75c72195d35201dc1fb210818993518c25da566b/B13_vits16_mae_ep99_enc.pth',
transforms=_zhu_xlab_transforms,
meta={
'dataset': 'SSL4EO-S12',
'in_chans': 13,
'model': 'vit_small_patch16_224',
'publication': 'https://arxiv.org/abs/2211.07044',
'repo': 'https://github.com/zhu-xlab/SSL4EO-S12',
'ssl_method': 'mae',
'bands': _sentinel2_toa_bands,
},
)
SENTINEL2_ALL_FGMAE = Weights(
url='https://huggingface.co/wangyi111/FGMAE/resolve/24dd3077d7a99ecd454eaec7adb83d045d7fa122/B13_vits16_fgmae_ep99_enc.pth',
transforms=_zhu_xlab_transforms,
meta={
'dataset': 'SSL4EO-S12',
'in_chans': 13,
'model': 'vit_small_patch16_224',
'publication': 'https://arxiv.org/abs/2310.18653',
'repo': 'https://github.com/zhu-xlab/FGMAE',
'ssl_method': 'fg-mae',
'bands': _sentinel2_toa_bands,
},
)
SENTINEL1_GRD_CLOSP = Weights(
url='https://huggingface.co/DarthReca/CLOSP-Visual/resolve/3bb8677c21dac56bea2dd7baa08d7871272db440/closp-vs_s1_encoder-180f1e6e.pth',
transforms=K.AugmentationSequential(K.Resize(224), data_keys=None),
meta={
'dataset': 'CrisisLandMark',
'in_chans': 2,
'model': 'vit_small_patch16_224',
'publication': 'https://arxiv.org/abs/2507.10403',
'repo': 'https://github.com/DarthReca/closp',
'bands': _sentinel1_grd_bands,
},
)
SENTINEL1_GRD_MAE = Weights(
url='https://huggingface.co/wangyi111/SSL4EO-S12/resolve/75c72195d35201dc1fb210818993518c25da566b/B2_vits16_mae_ep99_enc.pth',
transforms=_ssl4eo_s12_transforms_s1,
meta={
'dataset': 'SSL4EO-S12',
'in_chans': 2,
'model': 'vit_small_patch16_224',
'publication': 'https://arxiv.org/abs/2211.07044',
'repo': 'https://github.com/zhu-xlab/SSL4EO-S12',
'ssl_method': 'mae',
'bands': _sentinel1_grd_bands,
},
)
SENTINEL1_GRD_FGMAE = Weights(
url='https://huggingface.co/wangyi111/FGMAE/resolve/24dd3077d7a99ecd454eaec7adb83d045d7fa122/B2_vits16_fgmae_ep99_enc.pth',
transforms=_ssl4eo_s12_transforms_s1,
meta={
'dataset': 'SSL4EO-S12',
'in_chans': 2,
'model': 'vit_small_patch16_224',
'publication': 'https://arxiv.org/abs/2310.18653',
'repo': 'https://github.com/zhu-xlab/FGMAE',
'ssl_method': 'fg-mae',
'bands': _sentinel1_grd_bands,
},
)


class ViTBase16_Weights(WeightsEnum):  # type: ignore[misc]
    """Vision Transformer Base Patch Size 16 weights.

    For `timm <https://github.com/huggingface/pytorch-image-models>`_
    *vit_base_patch16_224* implementation.

    .. versionadded:: 0.7
    """

SENTINEL2_ALL_MAE = Weights(
url='https://huggingface.co/wangyi111/SSL4EO-S12/resolve/75c72195d35201dc1fb210818993518c25da566b/B13_vitb16_mae_ep99_enc.pth',
transforms=_zhu_xlab_transforms,
meta={
'dataset': 'SSL4EO-S12',
'in_chans': 13,
'model': 'vit_base_patch16_224',
'publication': 'https://arxiv.org/abs/2211.07044',
'repo': 'https://github.com/zhu-xlab/SSL4EO-S12',
'ssl_method': 'mae',
'bands': _sentinel2_toa_bands,
},
)
SENTINEL2_ALL_FGMAE = Weights(
url='https://huggingface.co/wangyi111/FGMAE/resolve/24dd3077d7a99ecd454eaec7adb83d045d7fa122/B13_vitb16_fgmae_ep99_enc.pth',
transforms=_zhu_xlab_transforms,
meta={
'dataset': 'SSL4EO-S12',
'in_chans': 13,
'model': 'vit_base_patch16_224',
'publication': 'https://arxiv.org/abs/2310.18653',
'repo': 'https://github.com/zhu-xlab/FGMAE',
'ssl_method': 'fg-mae',
'bands': _sentinel2_toa_bands,
},
)
SENTINEL1_GRD_MAE = Weights(
url='https://huggingface.co/wangyi111/SSL4EO-S12/resolve/75c72195d35201dc1fb210818993518c25da566b/B2_vitb16_mae_ep99_enc.pth',
transforms=_ssl4eo_s12_transforms_s1,
meta={
'dataset': 'SSL4EO-S12',
'in_chans': 2,
'model': 'vit_base_patch16_224',
'publication': 'https://arxiv.org/abs/2211.07044',
'repo': 'https://github.com/zhu-xlab/SSL4EO-S12',
'ssl_method': 'mae',
'bands': _sentinel1_grd_bands,
},
)
SENTINEL1_GRD_FGMAE = Weights(
url='https://huggingface.co/wangyi111/FGMAE/resolve/24dd3077d7a99ecd454eaec7adb83d045d7fa122/B2_vitb16_fgmae_ep99_enc.pth',
transforms=_ssl4eo_s12_transforms_s1,
meta={
'dataset': 'SSL4EO-S12',
'in_chans': 2,
'model': 'vit_base_patch16_224',
'publication': 'https://arxiv.org/abs/2310.18653',
'repo': 'https://github.com/zhu-xlab/FGMAE',
'ssl_method': 'fg-mae',
'bands': _sentinel1_grd_bands,
},
)


class ViTLarge16_Weights(WeightsEnum):  # type: ignore[misc]
    """Vision Transformer Large Patch Size 16 weights.

    For `timm <https://github.com/huggingface/pytorch-image-models>`_
    *vit_large_patch16_224* implementation.

    .. versionadded:: 0.7
    """

SENTINEL2_ALL_CLOSP = Weights(
url='https://huggingface.co/DarthReca/CLOSP-Visual/resolve/3bb8677c21dac56bea2dd7baa08d7871272db440/closp-vl_s2_encoder-4a4f026a.pth',
transforms=K.AugmentationSequential(
K.Normalize(mean=0, std=10000), K.Resize(224), data_keys=None
),
meta={
'dataset': 'CrisisLandMark',
'in_chans': 13,
'model': 'vit_large_patch16_224',
'publication': 'https://arxiv.org/abs/2507.10403',
'repo': 'https://github.com/DarthReca/closp',
'bands': _sentinel2_toa_bands,
},
)
SENTINEL2_ALL_MAE = Weights(
url='https://huggingface.co/wangyi111/SSL4EO-S12/resolve/75c72195d35201dc1fb210818993518c25da566b/B13_vitl16_mae_ep99_enc.pth',
transforms=_zhu_xlab_transforms,
meta={
'dataset': 'SSL4EO-S12',
'in_chans': 13,
'model': 'vit_large_patch16_224',
'publication': 'https://arxiv.org/abs/2211.07044',
'repo': 'https://github.com/zhu-xlab/SSL4EO-S12',
'ssl_method': 'mae',
'bands': _sentinel2_toa_bands,
},
)
SENTINEL2_ALL_FGMAE = Weights(
url='https://huggingface.co/wangyi111/FGMAE/resolve/24dd3077d7a99ecd454eaec7adb83d045d7fa122/B13_vitl16_fgmae_ep99_enc.pth',
transforms=_zhu_xlab_transforms,
meta={
'dataset': 'SSL4EO-S12',
'in_chans': 13,
'model': 'vit_large_patch16_224',
'publication': 'https://arxiv.org/abs/2310.18653',
'repo': 'https://github.com/zhu-xlab/FGMAE',
'ssl_method': 'fg-mae',
'bands': _sentinel2_toa_bands,
},
)
SENTINEL1_GRD_CLOSP = Weights(
url='https://huggingface.co/DarthReca/CLOSP-Visual/resolve/3bb8677c21dac56bea2dd7baa08d7871272db440/closp-vl_s1_encoder-6f88d037.pth',
transforms=K.AugmentationSequential(K.Resize(224), data_keys=None),
meta={
'dataset': 'CrisisLandMark',
'in_chans': 2,
'model': 'vit_large_patch16_224',
'publication': 'https://arxiv.org/abs/2507.10403',
'repo': 'https://github.com/DarthReca/closp',
'bands': _sentinel1_grd_bands,
},
)
SENTINEL1_GRD_MAE = Weights(
url='https://huggingface.co/wangyi111/SSL4EO-S12/resolve/75c72195d35201dc1fb210818993518c25da566b/B2_vitl16_mae_ep99_enc.pth',
transforms=_ssl4eo_s12_transforms_s1,
meta={
'dataset': 'SSL4EO-S12',
'in_chans': 2,
'model': 'vit_large_patch16_224',
'publication': 'https://arxiv.org/abs/2211.07044',
'repo': 'https://github.com/zhu-xlab/SSL4EO-S12',
'ssl_method': 'mae',
'bands': _sentinel1_grd_bands,
},
)
SENTINEL1_GRD_FGMAE = Weights(
url='https://huggingface.co/wangyi111/FGMAE/resolve/24dd3077d7a99ecd454eaec7adb83d045d7fa122/B2_vitl16_fgmae_ep99_enc.pth',
transforms=_ssl4eo_s12_transforms_s1,
meta={
'dataset': 'SSL4EO-S12',
'in_chans': 2,
'model': 'vit_large_patch16_224',
'publication': 'https://arxiv.org/abs/2310.18653',
'repo': 'https://github.com/zhu-xlab/FGMAE',
'ssl_method': 'fg-mae',
'bands': _sentinel1_grd_bands,
},
)


class ViTHuge14_Weights(WeightsEnum):  # type: ignore[misc]
    """Vision Transformer Huge Patch Size 14 weights.

    For `timm <https://github.com/huggingface/pytorch-image-models>`_
    *vit_huge_patch14_224* implementation.

    .. versionadded:: 0.7
    """

SENTINEL2_ALL_MAE = Weights(
url='https://huggingface.co/wangyi111/SSL4EO-S12/resolve/75c72195d35201dc1fb210818993518c25da566b/B13_vith14_mae_ep199_enc.pth',
transforms=_zhu_xlab_transforms,
meta={
'dataset': 'SSL4EO-S12',
'in_chans': 13,
'model': 'vit_huge_patch14_224',
'publication': 'https://arxiv.org/abs/2211.07044',
'repo': 'https://github.com/zhu-xlab/SSL4EO-S12',
'ssl_method': 'mae',
'bands': _sentinel2_toa_bands,
},
)
SENTINEL2_ALL_FGMAE = Weights(
url='https://huggingface.co/wangyi111/FGMAE/resolve/24dd3077d7a99ecd454eaec7adb83d045d7fa122/B13_vith14_fgmae_ep399_enc.pth',
transforms=_zhu_xlab_transforms,
meta={
'dataset': 'SSL4EO-S12',
'in_chans': 13,
'model': 'vit_huge_patch14_224',
'publication': 'https://arxiv.org/abs/2310.18653',
'repo': 'https://github.com/zhu-xlab/FGMAE',
'ssl_method': 'fg-mae',
'bands': _sentinel2_toa_bands,
},
)
SENTINEL1_GRD_MAE = Weights(
url='https://huggingface.co/wangyi111/SSL4EO-S12/resolve/75c72195d35201dc1fb210818993518c25da566b/B2_vith14_mae_ep199_enc.pth',
transforms=_ssl4eo_s12_transforms_s1,
meta={
'dataset': 'SSL4EO-S12',
'in_chans': 2,
'model': 'vit_huge_patch14_224',
'publication': 'https://arxiv.org/abs/2211.07044',
'repo': 'https://github.com/zhu-xlab/SSL4EO-S12',
'ssl_method': 'mae',
'bands': _sentinel1_grd_bands,
},
)
SENTINEL1_GRD_FGMAE = Weights(
url='https://huggingface.co/wangyi111/FGMAE/resolve/24dd3077d7a99ecd454eaec7adb83d045d7fa122/B2_vith14_fgmae_ep399_enc.pth',
transforms=_ssl4eo_s12_transforms_s1,
meta={
'dataset': 'SSL4EO-S12',
'in_chans': 2,
'model': 'vit_huge_patch14_224',
'publication': 'https://arxiv.org/abs/2310.18653',
'repo': 'https://github.com/zhu-xlab/FGMAE',
'ssl_method': 'fg-mae',
'bands': _sentinel1_grd_bands,
},
)


class ViTSmall14_DINOv2_Weights(WeightsEnum):  # type: ignore[misc]
    """Vision Transformer Small Patch Size 14 (DINOv2) weights.

    For `timm <https://github.com/huggingface/pytorch-image-models>`_
    *vit_small_patch14_dinov2* implementation.

    .. versionadded:: 0.7
    """

SENTINEL2_ALL_SOFTCON = Weights(
url='https://huggingface.co/wangyi111/softcon/resolve/bae909781911f8ec034b4b959992fae17b973c0c/B13_vits14_softcon_enc.pth',
transforms=_ssl4eo_s12_transforms_s2_stats,
meta={
'dataset': 'SSL4EO-S12',
'in_chans': 13,
'img_size': 224,
'model': 'vit_small_patch14_dinov2',
'publication': 'https://arxiv.org/abs/2405.20462',
'repo': 'https://github.com/zhu-xlab/softcon',
'ssl_method': 'softcon',
'bands': _sentinel2_toa_bands,
},
)
SENTINEL1_GRD_SOFTCON = Weights(
url='https://huggingface.co/wangyi111/softcon/resolve/bae909781911f8ec034b4b959992fae17b973c0c/B2_vits14_softcon_enc.pth',
transforms=_ssl4eo_s12_transforms_s1,
meta={
'dataset': 'SSL4EO-S12',
'in_chans': 2,
'img_size': 224,
'model': 'vit_small_patch14_dinov2',
'publication': 'https://arxiv.org/abs/2405.20462',
'repo': 'https://github.com/zhu-xlab/softcon',
'ssl_method': 'softcon',
'bands': _sentinel1_grd_bands,
},
)


class ViTBase14_DINOv2_Weights(WeightsEnum):  # type: ignore[misc]
    """Vision Transformer Base Patch Size 14 (DINOv2) weights.

    For `timm <https://github.com/huggingface/pytorch-image-models>`_
    *vit_base_patch14_dinov2* implementation.

    .. versionadded:: 0.7
    """

SENTINEL2_ALL_SOFTCON = Weights(
url='https://huggingface.co/wangyi111/softcon/resolve/bae909781911f8ec034b4b959992fae17b973c0c/B13_vitb14_softcon_enc.pth',
transforms=_ssl4eo_s12_transforms_s2_stats,
meta={
'dataset': 'SSL4EO-S12',
'in_chans': 13,
'img_size': 224,
'model': 'vit_base_patch14_dinov2',
'publication': 'https://arxiv.org/abs/2405.20462',
'repo': 'https://github.com/zhu-xlab/softcon',
'ssl_method': 'softcon',
'bands': _sentinel2_toa_bands,
},
)
SENTINEL1_GRD_SOFTCON = Weights(
url='https://huggingface.co/wangyi111/softcon/resolve/bae909781911f8ec034b4b959992fae17b973c0c/B2_vitb14_softcon_enc.pth',
transforms=_ssl4eo_s12_transforms_s1,
meta={
'dataset': 'SSL4EO-S12',
'in_chans': 2,
'img_size': 224,
'model': 'vit_base_patch14_dinov2',
'publication': 'https://arxiv.org/abs/2405.20462',
'repo': 'https://github.com/zhu-xlab/softcon',
'ssl_method': 'softcon',
'bands': _sentinel1_grd_bands,
},
)


def vit_small_patch16_224(
    weights: ViTSmall16_Weights | None = None, *args: Any, **kwargs: Any
) -> nn.Module:
    """Vision Transformer (ViT) small patch size 16 model.

    If you use this model in your research, please cite the following paper:

    * https://arxiv.org/abs/2010.11929

    .. versionadded:: 0.4

    Args:
        weights: Pre-trained model weights to use.
        *args: Additional arguments to pass to :func:`timm.create_model`.
        **kwargs: Additional keyword arguments to pass to :func:`timm.create_model`.

    Returns:
        A ViT small 16 model.
    """
if weights:
kwargs['in_chans'] = weights.meta['in_chans']
model = timm.create_model('vit_small_patch16_224', *args, **kwargs)
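    # With features_only=True, timm wraps the ViT in a feature-extraction
    # module and the underlying network lives in its .model attribute, so the
    # state dict must be loaded into that inner module instead.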
if kwargs.get('features_only', False):
target_model = cast(nn.Module, model.model)
else:
target_model = model
if weights:
missing_keys, unexpected_keys = target_model.load_state_dict(
weights.get_state_dict(progress=True), strict=False
)
assert set(missing_keys) <= KEYS
        # when features_only=True the model has no norm/head, so those
        # checkpoint keys show up as unexpected
assert set(unexpected_keys) <= KEYS
return model
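
# Example usage (a sketch; the checkpoint downloads on first call):
#
#   weights = ViTSmall16_Weights.SENTINEL2_ALL_MOCO
#   model = vit_small_patch16_224(weights=weights)
#   x = torch.rand(1, weights.meta['in_chans'], 224, 224)
#   features = model.forward_features(x)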


def vit_base_patch16_224(
    weights: ViTBase16_Weights | None = None, *args: Any, **kwargs: Any
) -> nn.Module:
    """Vision Transformer (ViT) base patch size 16 model.

    If you use this model in your research, please cite the following paper:

    * https://arxiv.org/abs/2010.11929

    .. versionadded:: 0.7

    Args:
        weights: Pre-trained model weights to use.
        *args: Additional arguments to pass to :func:`timm.create_model`.
        **kwargs: Additional keyword arguments to pass to :func:`timm.create_model`.

    Returns:
        A ViT base 16 model.
    """
if weights:
kwargs['in_chans'] = weights.meta['in_chans']
model = timm.create_model('vit_base_patch16_224', *args, **kwargs)
if kwargs.get('features_only', False):
target_model = cast(nn.Module, model.model)
else:
target_model = model
if weights:
missing_keys, unexpected_keys = target_model.load_state_dict(
weights.get_state_dict(progress=True), strict=False
)
assert set(missing_keys) <= KEYS
assert set(unexpected_keys) <= KEYS
return model


def vit_large_patch16_224(
    weights: ViTLarge16_Weights | None = None, *args: Any, **kwargs: Any
) -> nn.Module:
    """Vision Transformer (ViT) large patch size 16 model.

    If you use this model in your research, please cite the following paper:

    * https://arxiv.org/abs/2010.11929

    .. versionadded:: 0.7

    Args:
        weights: Pre-trained model weights to use.
        *args: Additional arguments to pass to :func:`timm.create_model`.
        **kwargs: Additional keyword arguments to pass to :func:`timm.create_model`.

    Returns:
        A ViT large 16 model.
    """
if weights:
kwargs['in_chans'] = weights.meta['in_chans']
model = timm.create_model('vit_large_patch16_224', *args, **kwargs)
if kwargs.get('features_only', False):
target_model = cast(nn.Module, model.model)
else:
target_model = model
if weights:
missing_keys, unexpected_keys = target_model.load_state_dict(
weights.get_state_dict(progress=True), strict=False
)
assert set(missing_keys) <= KEYS
assert set(unexpected_keys) <= KEYS
return model


def vit_huge_patch14_224(
    weights: ViTHuge14_Weights | None = None, *args: Any, **kwargs: Any
) -> nn.Module:
    """Vision Transformer (ViT) huge patch size 14 model.

    If you use this model in your research, please cite the following paper:

    * https://arxiv.org/abs/2010.11929

    .. versionadded:: 0.7

    Args:
        weights: Pre-trained model weights to use.
        *args: Additional arguments to pass to :func:`timm.create_model`.
        **kwargs: Additional keyword arguments to pass to :func:`timm.create_model`.

    Returns:
        A ViT huge 14 model.
    """
if weights:
kwargs['in_chans'] = weights.meta['in_chans']
model = timm.create_model('vit_huge_patch14_224', *args, **kwargs)
if kwargs.get('features_only', False):
target_model = cast(nn.Module, model.model)
else:
target_model = model
if weights:
missing_keys, unexpected_keys = target_model.load_state_dict(
weights.get_state_dict(progress=True), strict=False
)
assert set(missing_keys) <= KEYS
assert set(unexpected_keys) <= KEYS
return model


def vit_small_patch14_dinov2(
    weights: ViTSmall14_DINOv2_Weights | None = None, *args: Any, **kwargs: Any
) -> nn.Module:
    """Vision Transformer (ViT) small patch size 14 model for DINOv2.

    If you use this model in your research, please cite the following paper:

    * https://arxiv.org/abs/2304.07193

    .. versionadded:: 0.7

    Args:
        weights: Pre-trained model weights to use.
        *args: Additional arguments to pass to :func:`timm.create_model`.
        **kwargs: Additional keyword arguments to pass to :func:`timm.create_model`.

    Returns:
        A DINOv2 ViT small 14 model.
    """
if weights:
kwargs['in_chans'] = weights.meta['in_chans']
kwargs['img_size'] = weights.meta['img_size']
model = timm.create_model('vit_small_patch14_dinov2', *args, **kwargs)
if kwargs.get('features_only', False):
target_model = cast(nn.Module, model.model)
else:
target_model = model
if weights:
missing_keys, unexpected_keys = target_model.load_state_dict(
weights.get_state_dict(progress=True), strict=False
)
assert set(missing_keys) <= KEYS
assert set(unexpected_keys) <= KEYS
return model
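
# Unlike the patch-16 factories above, the DINOv2 factories also forward
# weights.meta['img_size'], since timm's vit_*_patch14_dinov2 models default
# to a larger input resolution than the 224x224 used for pretraining. A sketch:
#
#   weights = ViTSmall14_DINOv2_Weights.SENTINEL1_GRD_SOFTCON
#   model = vit_small_patch14_dinov2(weights=weights)
#   feats = model.forward_features(torch.rand(1, 2, 224, 224))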


def vit_base_patch14_dinov2(
    weights: ViTBase14_DINOv2_Weights | None = None, *args: Any, **kwargs: Any
) -> nn.Module:
    """Vision Transformer (ViT) base patch size 14 model for DINOv2.

    If you use this model in your research, please cite the following paper:

    * https://arxiv.org/abs/2304.07193

    .. versionadded:: 0.7

    Args:
        weights: Pre-trained model weights to use.
        *args: Additional arguments to pass to :func:`timm.create_model`.
        **kwargs: Additional keyword arguments to pass to :func:`timm.create_model`.

    Returns:
        A DINOv2 ViT base 14 model.
    """
if weights:
kwargs['in_chans'] = weights.meta['in_chans']
kwargs['img_size'] = weights.meta['img_size']
model = timm.create_model('vit_base_patch14_dinov2', *args, **kwargs)
if kwargs.get('features_only', False):
target_model = cast(nn.Module, model.model)
else:
target_model = model
if weights:
missing_keys, unexpected_keys = target_model.load_state_dict(
weights.get_state_dict(progress=True), strict=False
)
assert set(missing_keys) <= KEYS
assert set(unexpected_keys) <= KEYS
return model