Shortcuts

Source code for torchgeo.models.earthloc

# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.

"""EarthLoc."""

import math
from typing import Any

import kornia.augmentation as K
import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch import Tensor
from torchvision.models._api import Weights, WeightsEnum

# Band order of the imagery the pre-trained weights expect (red, green, blue).
# Note the images used are Sentinel-2 Cloudless RGB Mosaics from https://s2maps.eu/
_earthloc_sentinel2_bands = ['B4', 'B3', 'B2']

# Preprocessing pipeline shipped with the pre-trained weights (it is passed as
# ``transforms`` to ``EarthLoc_Weights.SENTINEL2_RESNET50``). It mirrors the
# upstream training code:
# https://github.com/gmberton/EarthLoc/blob/2da231ae7ec9764fac6cde2aa88a17db23c1bb6a/datasets/train_dataset.py#L43
# https://github.com/gmberton/EarthLoc/blob/2da231ae7ec9764fac6cde2aa88a17db23c1bb6a/augmentations.py#L40
# Steps, in order: divide by 255, normalize with ImageNet mean and std, then
# resize to 320x320 (the ``image_size`` recorded in the weights metadata).
_earthloc_transforms = K.AugmentationSequential(
    # (x - 0) / 255: rescales 8-bit pixel values.
    K.Normalize(mean=torch.tensor(0.0), std=torch.tensor(255.0)),
    # Per-channel ImageNet statistics, in RGB order.
    K.Normalize(
        mean=torch.tensor([0.485, 0.456, 0.406]),
        std=torch.tensor([0.229, 0.224, 0.225]),
    ),
    K.Resize((320, 320)),
    # NOTE(review): data_keys=None presumably lets kornia infer the data keys
    # from the input it receives -- confirm against the kornia documentation.
    data_keys=None,
)


class FeatureMixerLayer(nn.Module):
    """Residual MLP block used as the building unit of MixVPR.

    The input is layer-normalized, passed through a two-layer perceptron with a
    ReLU in between, and the result is added back onto the untouched input.

    Adapted from https://github.com/gmberton/EarthLoc. Copyright (c) 2024 Gabriele Berton

    .. versionadded:: 0.8
    """

    def __init__(self, input_dim: int, mlp_ratio: int = 1) -> None:
        """Build the mixer block and initialize its linear layers.

        Args:
            input_dim: Size of the last dimension of the tensors to be mixed.
            mlp_ratio: Multiplier giving the hidden width of the MLP relative
                to ``input_dim`` (the product is truncated to an integer).
        """
        super().__init__()
        hidden_dim = int(input_dim * mlp_ratio)
        # Index layout (0: norm, 1: linear, 2: relu, 3: linear) defines the
        # state-dict keys, so it must not be reordered.
        self.mix = nn.Sequential(
            nn.LayerNorm(input_dim),
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, input_dim),
        )

        # Re-initialize every linear layer: truncated-normal weights, zero bias.
        for layer in self.modules():
            if not isinstance(layer, nn.Linear):
                continue
            nn.init.trunc_normal_(layer.weight, std=0.02)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the MLP and add the residual connection.

        Args:
            x: Input tensor of shape (batch_size, num_features, feature_dim).

        Returns:
            Tensor of the same shape as ``x`` holding ``x + mlp(norm(x))``.
        """
        return x + self.mix(x)


class MixVPR(nn.Module):
    """MixVPR aggregator turning a 2D feature map into a global descriptor.

    Spatial positions are flattened and mixed by a stack of
    :class:`FeatureMixerLayer` blocks, after which the channel axis and the
    flattened spatial axis are each reduced by a linear projection. The result
    is flattened and L2-normalized.

    Adapted from https://github.com/gmberton/EarthLoc. Copyright (c) 2024 Gabriele Berton

    If you use this model in your research, please cite the following paper:

    * https://arxiv.org/abs/2303.02190

    .. versionadded:: 0.8
    """

    def __init__(
        self,
        in_channels: int = 1024,
        in_h: int = 20,
        in_w: int = 20,
        out_channels: int = 512,
        mix_depth: int = 1,
        mlp_ratio: int = 1,
        out_rows: int = 4,
    ) -> None:
        """Set up the mixer stack and the two projection layers.

        Args:
            in_channels: Channel count of the incoming feature maps.
            in_h: Height of the incoming feature maps.
            in_w: Width of the incoming feature maps.
            out_channels: Channel count after the channel projection.
            mix_depth: How many FeatureMixer layers are stacked.
            mlp_ratio: Hidden-width multiplier forwarded to every mixer layer.
            out_rows: Size the flattened spatial axis is projected down to.
        """
        super().__init__()
        # Keep the configuration around for introspection.
        self.in_channels = in_channels
        self.in_h = in_h
        self.in_w = in_w
        self.out_channels = out_channels
        self.out_rows = out_rows
        self.mix_depth = mix_depth
        self.mlp_ratio = mlp_ratio

        # The mixers operate along the flattened spatial axis.
        num_positions = in_h * in_w
        mixer_layers = []
        for _ in range(self.mix_depth):
            mixer_layers.append(
                FeatureMixerLayer(input_dim=num_positions, mlp_ratio=mlp_ratio)
            )
        self.mix = nn.Sequential(*mixer_layers)
        self.channel_proj = nn.Linear(in_channels, out_channels)
        self.row_proj = nn.Linear(num_positions, out_rows)

    def forward(self, x: Tensor) -> Tensor:
        """Aggregate a feature map into a single normalized descriptor.

        Args:
            x: Input 2D image embeddings of shape (b, c, h, w).

        Returns:
            L2-normalized descriptor of shape (b, out_channels * out_rows).
        """
        # (b, c, h, w) -> (b, c, h*w), then mix along the spatial axis.
        flat = rearrange(x, 'b c h w -> b c (h w)')
        mixed = self.mix(flat)
        # Bring channels last so the linear layer reduces them.
        channels_last = rearrange(mixed, 'b c d -> b d c')
        reduced_channels = self.channel_proj(channels_last)
        # Bring the spatial axis last again and reduce it to ``out_rows``.
        spatial_last = rearrange(reduced_channels, 'b d c -> b c d')
        reduced_rows = self.row_proj(spatial_last)
        descriptor = rearrange(reduced_rows, 'b c d -> b (c d)')
        return F.normalize(descriptor, p=2, dim=1)


class EarthLoc(nn.Module):
    """EarthLoc model for generating feature descriptors from satellite imagery.

    A timm backbone (with its ``layer4`` stage replaced by an identity) extracts
    a feature map, a :class:`MixVPR` aggregator condenses it into a vector, and
    a final linear layer followed by L2 normalization yields the descriptor.

    Adapted from https://github.com/gmberton/EarthLoc. Copyright (c) 2024 Gabriele Berton

    If you use this model in your research, please cite the following paper:

    * https://arxiv.org/abs/2403.06758

    .. versionadded:: 0.8
    """

    def __init__(
        self,
        in_channels: int = 3,
        image_size: int = 320,
        desc_dim: int = 4096,
        backbone: str = 'resnet50',
        pretrained: bool = True,
    ) -> None:
        """Initialize the EarthLoc model.

        Args:
            in_channels: Number of input channels in the images (default: 3 for RGB).
            image_size: Size of the input images (assumed square).
            desc_dim: Dimension of the final output feature descriptor. Must be
                divisible by 4, the number of rows the aggregator emits per
                channel.
            backbone: Backbone model to use for feature extraction (default: "resnet50").
            pretrained: Whether to use pre-trained weights for the backbone model.

        Raises:
            ValueError: If ``desc_dim`` is not divisible by 4.
        """
        super().__init__()

        # The aggregator emits ``out_channels * out_rows`` values, which must
        # equal ``desc_dim`` for the final linear layer to accept them. Fail
        # here, before any backbone weights are fetched, rather than with a
        # shape error on the first forward pass.
        out_rows = 4
        if desc_dim % out_rows != 0:
            raise ValueError(
                f'desc_dim must be divisible by {out_rows}, got {desc_dim}'
            )

        self.image_size = image_size
        self.backbone = timm.create_model(
            backbone,
            pretrained=pretrained,
            in_chans=in_channels,
            num_classes=0,
            global_pool='',
        )
        # Drop the last stage so the backbone returns the preceding feature map.
        # NOTE(review): the hard-coded 1024 channels and the /16 spatial size
        # below presumably match a ResNet-50 truncated after layer3; other
        # backbones may need different values -- confirm before using them.
        self.backbone.layer4 = nn.Identity()

        out_channels = desc_dim // out_rows
        feature_size = math.ceil(image_size / 16)
        self.aggregator = MixVPR(
            in_channels=1024,
            in_h=feature_size,
            in_w=feature_size,
            out_channels=out_channels,
            mix_depth=4,
            mlp_ratio=1,
            out_rows=out_rows,
        )
        self.fc = nn.Linear(desc_dim, desc_dim)
        self.desc_dim = desc_dim  # Dimension of final descriptor

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass of the EarthLoc model.

        Args:
            x: Input tensor of shape (b, c, h, w).

        Returns:
            Output feature descriptor tensor of shape (b, desc_dim),
            L2-normalized along the last dimension.
        """
        x = self.backbone(x)
        x = self.aggregator(x)
        x = self.fc(x)
        x = F.normalize(x, p=2, dim=-1)
        return x
class EarthLoc_Weights(WeightsEnum):  # type: ignore[misc]
    """Pre-trained checkpoints available for :class:`EarthLoc`.

    Each member bundles the download URL, the preprocessing transforms and the
    metadata that :func:`earthloc` uses to configure the model.
    """

    SENTINEL2_RESNET50 = Weights(
        url='https://huggingface.co/torchgeo/earthloc/resolve/53a4bb90a7754b12f44986521ac7a711b4795959/earthloc-8b632e30.pth',
        transforms=_earthloc_transforms,
        meta={
            'dataset': 'EarthLoc',
            # Constructor settings read back by the ``earthloc`` factory.
            'in_chans': 3,
            'image_size': 320,
            'desc_dim': 4096,
            'encoder': 'resnet50',
            'bands': _earthloc_sentinel2_bands,
            # Provenance.
            'publication': 'https://arxiv.org/abs/2403.06758',
            'repo': 'https://github.com/gmberton/EarthLoc',
        },
    )
def earthloc(
    weights: EarthLoc_Weights | None = None, *args: Any, **kwargs: Any
) -> EarthLoc:
    """Build an EarthLoc model, optionally loading pre-trained weights.

    If you use this model in your research, please cite the following paper:

    * https://arxiv.org/abs/2403.06758

    .. versionadded:: 0.8

    Args:
        weights: Pre-trained model weights to use. When given, the model
            configuration stored in the weights metadata overrides any
            matching entries in ``kwargs``.
        *args: Additional arguments to pass to :class:`EarthLoc`.
        **kwargs: Additional keyword arguments to pass to :class:`EarthLoc`.

    Returns:
        An EarthLoc model.
    """
    if not weights:
        # No checkpoint requested: build the model exactly as the caller asked.
        return EarthLoc(*args, **kwargs)

    # The architecture must match the checkpoint, so its metadata takes
    # precedence. The backbone's own pre-trained weights are skipped because
    # the full state dict is loaded right after construction.
    meta = weights.meta
    checkpoint_config = {
        'in_channels': meta['in_chans'],
        'image_size': meta['image_size'],
        'desc_dim': meta['desc_dim'],
        'backbone': meta['encoder'],
        'pretrained': False,
    }
    kwargs.update(checkpoint_config)

    model = EarthLoc(*args, **kwargs)
    state_dict = weights.get_state_dict(progress=True)
    model.load_state_dict(state_dict, strict=True)
    return model

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources