Source code for torchgeo.models.croma

# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.

# Code based on https://github.com/antofuller/CROMA under MIT License

"""CROMA model."""

import itertools
import math
from collections.abc import Sequence
from typing import Any

import torch
from einops import rearrange
from torch import Tensor, einsum, nn
from torchvision.models._api import Weights, WeightsEnum


class CROMA(nn.Module):
    """Pretrained CROMA model.

    Corresponds to the pretrained CROMA model found in the CROMA repository:

    * https://github.com/antofuller/CROMA/blob/main/pretrain_croma.py

    If you use this model in your research, please cite the following paper:

    * https://arxiv.org/abs/2311.00566
    """

    valid_modalities = ('sar', 'optical')

    def __init__(
        self,
        modalities: Sequence[str] = ['sar', 'optical'],
        encoder_dim: int = 768,
        encoder_depth: int = 12,
        num_heads: int = 16,
        patch_size: int = 8,
        image_size: int = 120,
    ) -> None:
        """Initialize the CROMA model.

        Args:
            modalities: List of modalities used during forward pass, list can
                contain 'sar', 'optical', or both.
            encoder_dim: Dimension of the encoder.
            encoder_depth: Depth of the encoder.
            num_heads: Number of heads for the multi-head attention, should be a
                power of 2.
            patch_size: Size of the patches.
            image_size: Size of the input images, CROMA was trained on 120x120
                images, must be a multiple of 8.

        Raises:
            AssertionError: If any arguments are not valid.
        """
        super().__init__()
        for modality in modalities:
            assert modality in self.valid_modalities, (
                f'{modality} is not a valid modality'
            )

        assert image_size % 8 == 0, 'image_size must be a multiple of 8'
        assert num_heads % 2 == 0, 'num_heads must be a power of 2'

        self.modalities = modalities
        self.encoder_dim = encoder_dim
        self.encoder_depth = encoder_depth
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.image_size = image_size
        self.num_patches = int((image_size / 8) ** 2)
        self.s1_channels = 2  # fixed at 2 SAR backscatter channels
        self.s2_channels = 12  # fixed at 12 multispectral optical channels
        self.attn_bias = get_2dalibi(
            num_heads=self.num_heads, num_patches=self.num_patches
        )

        def initialize_encoder(
            encoder_dim: int, encoder_depth: int, in_channels: int
        ) -> tuple[nn.Module, nn.Module]:
            """Initialize the encoder and GAP-FFN for a given modality.

            Args:
                encoder_dim: Dimension of the encoder.
                encoder_depth: Depth of the encoder.
                in_channels: Number of input channels.

            Returns:
                Tuple containing the encoder and GAP-FFN.
            """
            encoder = ViT(
                dim=encoder_dim, depth=encoder_depth, in_channels=in_channels
            )
            gap_ffn = nn.Sequential(
                nn.LayerNorm(encoder_dim),
                nn.Linear(encoder_dim, int(4 * encoder_dim)),
                nn.GELU(),
                nn.Linear(int(4 * encoder_dim), encoder_dim),
            )
            return encoder, gap_ffn

        if 'sar' in modalities:
            self.s1_encoder, self.s1_GAP_FFN = initialize_encoder(
                encoder_dim, int(encoder_depth / 2), self.s1_channels
            )
        if 'optical' in modalities:
            self.s2_encoder, self.s2_GAP_FFN = initialize_encoder(
                encoder_dim, encoder_depth, self.s2_channels
            )
        if set(self.modalities) == {'sar', 'optical'}:
            self.joint_encoder = BaseTransformerCrossAttn(
                dim=encoder_dim, depth=int(encoder_depth / 2), num_heads=num_heads
            )

    def forward(
        self, x_sar: Tensor | None = None, x_optical: Tensor | None = None
    ) -> dict[str, Tensor]:
        """Forward pass of the CROMA model.

        Args:
            x_sar: Input mini-batch of SAR images [B, 2, H, W].
            x_optical: Input mini-batch of optical images [B, 12, H, W].

        Returns:
            Dictionary with per-patch encodings and pooled (GAP) features for the
            requested modalities, plus joint encodings if both modalities are used.
        """
        return_dict: dict[str, Tensor] = {}
        if 'sar' in self.modalities and x_sar is not None:
            sar_encodings = self.s1_encoder(imgs=x_sar, attn_bias=self.attn_bias)
            sar_GAP = self.s1_GAP_FFN(sar_encodings.mean(dim=1))
            return_dict['sar_encodings'] = sar_encodings
            return_dict['sar_GAP'] = sar_GAP
        if 'optical' in self.modalities and x_optical is not None:
            optical_encodings = self.s2_encoder(
                imgs=x_optical, attn_bias=self.attn_bias
            )
            optical_GAP = self.s2_GAP_FFN(optical_encodings.mean(dim=1))
            return_dict['optical_encodings'] = optical_encodings
            return_dict['optical_GAP'] = optical_GAP
        if set(self.modalities) == {'sar', 'optical'}:
            joint_encodings = self.joint_encoder(
                x=sar_encodings,
                context=optical_encodings,
                relative_position_bias=self.attn_bias,
            )
            joint_GAP = joint_encodings.mean(dim=1)
            return_dict['joint_encodings'] = joint_encodings
            return_dict['joint_GAP'] = joint_GAP
        return return_dict
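

# --- Example (not part of the original module) --------------------------------
# A minimal sketch of how the CROMA model above might be used on random
# SAR/optical mini-batches. The helper name and the batch size are illustrative
# assumptions; tensor shapes follow the forward() docstring.
def _demo_croma_forward() -> None:
    """Run a joint SAR + optical forward pass on random inputs (sketch)."""
    model = CROMA(modalities=['sar', 'optical'], image_size=120)
    # 2D ALiBi attention bias: one (225 x 225) penalty map per head,
    # where 225 = (120 / 8) ** 2 patches.
    print(model.attn_bias.shape)  # [1, 16, 225, 225]
    x_sar = torch.randn(2, 2, 120, 120)  # [B, 2, H, W] SAR backscatter
    x_optical = torch.randn(2, 12, 120, 120)  # [B, 12, H, W] multispectral optical
    out = model(x_sar=x_sar, x_optical=x_optical)
    # Per-patch encodings are [B, num_patches, encoder_dim] = [2, 225, 768];
    # pooled GAP features are [B, encoder_dim] = [2, 768].
    print(out['joint_encodings'].shape, out['joint_GAP'].shape)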


def get_2dalibi(num_heads: int, num_patches: int) -> Tensor:
    """Get 2D relative position bias for the attention layer.

    Args:
        num_heads: Number of heads for the multi-head attention.
        num_patches: Number of patches.

    Returns:
        2D relative position bias tensor.
    """
    # inspired by: https://github.com/ofirpress/attention_with_linear_biases
    points = list(
        itertools.product(
            range(int(math.sqrt(num_patches))), range(int(math.sqrt(num_patches)))
        )
    )

    def get_slopes(n: int) -> list[float]:
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        ratio = start
        return [start * ratio**i for i in range(n)]

    slopes = torch.Tensor(get_slopes(num_heads)).unsqueeze(1)
    idxs = []
    for p1 in points:
        for p2 in points:
            dist = math.sqrt((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2)
            idxs.append(dist * slopes * -1)
    all_bias = torch.cat(idxs, dim=1)
    return all_bias.view(1, num_heads, num_patches, num_patches)


class FFN(nn.Module):
    """Feed-forward network for the transformer."""

    def __init__(self, dim: int, mult: int = 4, dropout: float = 0.0) -> None:
        """Initialize the feed-forward network.

        Args:
            dim: Dimension of the input.
            mult: Multiplier for the inner dimension of the feed-forward network.
            dropout: Dropout probability.
        """
        super().__init__()
        inner_dim = int(dim * mult)

        self.net = nn.Sequential(
            nn.Linear(dim, inner_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim),
        )
        self.input_norm = nn.LayerNorm(dim)

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass of the feed-forward network.

        Args:
            x: Input tensor.

        Returns:
            Output tensor.
        """
        x = self.input_norm(x)
        x = self.net(x)
        return x


class Attention(nn.Module):
    """Multi-head attention layer for the transformer."""

    def __init__(self, dim: int, num_heads: int = 8, dropout: float = 0.0) -> None:
        """Initialize the multi-head attention layer.

        Args:
            dim: Dimension of the input.
            num_heads: Number of heads for the multi-head attention.
            dropout: Dropout probability.
        """
        super().__init__()
        self.num_heads = num_heads
        assert dim % num_heads == 0, 'dim must be evenly divisible by num_heads'
        dim_head = int(dim / num_heads)
        self.scale = dim_head**-0.5

        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.to_out = nn.Linear(dim, dim)
        self.input_norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: Tensor, relative_position_bias: Tensor) -> Tensor:
        """Forward pass of the multi-head attention layer.

        Args:
            x: Input tensor.
            relative_position_bias: Relative position bias tensor.

        Returns:
            Output tensor.
        """
        x = self.input_norm(x)
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(
            lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.num_heads), (q, k, v)
        )

        attention_scores = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attention_scores = attention_scores + relative_position_bias
        attn = attention_scores.softmax(dim=-1)
        attn = self.dropout(attn)

        x = einsum('b h i j, b h j d -> b h i d', attn, v)
        x = rearrange(x, 'b h n d -> b n (h d)')
        x = self.to_out(x)
        return x


class CrossAttention(nn.Module):
    """Cross-attention layer for the transformer."""

    def __init__(self, dim: int, num_heads: int = 8, dropout: float = 0.0) -> None:
        """Initialize the cross-attention layer.

        Args:
            dim: Dimension of the input.
            num_heads: Number of heads for the multi-head attention.
            dropout: Dropout probability.

        Raises:
            AssertionError: If the dimension is not evenly divisible by the number
                of heads.
        """
        super().__init__()
        self.num_heads = num_heads
        assert dim % num_heads == 0, 'dim must be evenly divisible by num_heads'
        dim_head = int(dim / num_heads)
        self.scale = dim_head**-0.5

        self.to_q = nn.Linear(dim, dim, bias=False)
        self.to_k = nn.Linear(dim, dim, bias=False)
        self.to_v = nn.Linear(dim, dim, bias=False)
        self.to_out = nn.Linear(dim, dim)
        self.input_norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self, x: Tensor, context: Tensor, relative_position_bias: Tensor
    ) -> Tensor:
        """Forward pass of the cross-attention layer.

        Args:
            x: Input tensor.
            context: Context tensor.
            relative_position_bias: Relative position bias tensor.

        Returns:
            Output tensor.
        """
        x = self.input_norm(x)
        context = self.input_norm(context)

        q = self.to_q(x)
        k = self.to_k(context)
        v = self.to_v(context)
        q, k, v = map(
            lambda t: rearrange(t, 'b n (h d) -> b h n d', h=self.num_heads), (q, k, v)
        )

        attention_scores = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        attention_scores = attention_scores + relative_position_bias
        attn = attention_scores.softmax(dim=-1)
        attn = self.dropout(attn)

        x = einsum('b h i j, b h j d -> b h i d', attn, v)
        x = rearrange(x, 'b h n d -> b n (h d)')
        x = self.to_out(x)
        return x


class BaseTransformer(nn.Module):
    """Base transformer model."""

    def __init__(
        self,
        dim: int,
        depth: int,
        num_heads: int = 8,
        attn_dropout: float = 0.0,
        ff_dropout: float = 0.0,
        ff_mult: int = 4,
        final_norm: bool = True,
    ) -> None:
        """Initialize the base transformer model.

        Args:
            dim: Dimension of the input.
            depth: Depth of the transformer.
            num_heads: Number of heads for the multi-head attention.
            attn_dropout: Dropout probability for the attention layer.
            ff_dropout: Dropout probability for the feed-forward network.
            ff_mult: Multiplier for the inner dimension of the feed-forward network.
            final_norm: Whether to apply a final layer normalization.
        """
        super().__init__()
        self.final_norm = final_norm
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        Attention(dim=dim, num_heads=num_heads, dropout=attn_dropout),
                        FFN(dim=dim, mult=ff_mult, dropout=ff_dropout),
                    ]
                )
            )

        if self.final_norm:
            self.norm_out = nn.LayerNorm(dim)

    def forward(self, x: Tensor, relative_position_bias: Tensor) -> Tensor:
        """Forward pass of the base transformer model.

        Args:
            x: Input tensor.
            relative_position_bias: Relative position bias tensor.

        Returns:
            Output tensor.
        """
        for self_attn, ffn in self.layers:  # type: ignore[misc]
            x = self_attn(x, relative_position_bias) + x  # type: ignore[has-type]
            x = ffn(x) + x  # type: ignore[has-type]

        x = self.norm_out(x) if self.final_norm else x
        return x


class BaseTransformerCrossAttn(nn.Module):
    """Base transformer model with cross-attention."""

    def __init__(
        self,
        dim: int,
        depth: int,
        num_heads: int = 8,
        attn_dropout: float = 0.0,
        ff_dropout: float = 0.0,
        ff_mult: int = 4,
    ) -> None:
        """Initialize the base transformer model with cross-attention.

        Args:
            dim: Dimension of the input.
            depth: Depth of the transformer.
            num_heads: Number of heads for the multi-head attention.
            attn_dropout: Dropout probability for the attention layer.
            ff_dropout: Dropout probability for the feed-forward network.
            ff_mult: Multiplier for the inner dimension of the feed-forward network.
        """
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(
                nn.ModuleList(
                    [
                        Attention(dim=dim, num_heads=num_heads, dropout=attn_dropout),
                        CrossAttention(
                            dim=dim, num_heads=num_heads, dropout=attn_dropout
                        ),
                        FFN(dim=dim, mult=ff_mult, dropout=ff_dropout),
                    ]
                )
            )

        self.norm_out = nn.LayerNorm(dim)

    def forward(
        self, x: Tensor, context: Tensor, relative_position_bias: Tensor
    ) -> Tensor:
        """Forward pass of the base transformer model with cross-attention.

        Args:
            x: Input tensor.
            context: Context tensor.
            relative_position_bias: Relative position bias tensor.

        Returns:
            Output tensor.
        """
        for self_attn, cross_attn, ffn in self.layers:  # type: ignore[misc]
            x = self_attn(x, relative_position_bias) + x  # type: ignore[has-type]
            x = cross_attn(x, context, relative_position_bias) + x  # type: ignore[has-type]
            x = ffn(x) + x  # type: ignore[has-type]

        x = self.norm_out(x)
        return x


class ViT(nn.Module):
    """Vision Transformer model."""

    def __init__(self, dim: int, depth: int, in_channels: int) -> None:
        """Initialize the vision transformer model.

        Args:
            dim: Dimension of the input.
            depth: Depth of the transformer.
            in_channels: Number of input channels.
        """
        super().__init__()
        self.depth = depth
        self.in_channels = in_channels
        self.dim = dim
        self.num_heads = 16  # always 16, for base and large models
        self.patch_size = 8  # always 8, for base and large models

        pixels_per_patch = int(self.patch_size * self.patch_size * in_channels)
        self.linear_input = nn.Linear(pixels_per_patch, self.dim)
        self.transformer = BaseTransformer(
            dim=self.dim, depth=self.depth, num_heads=self.num_heads
        )

    def forward(self, imgs: Tensor, attn_bias: Tensor) -> Tensor:
        """Forward pass of the vision transformer model.

        Args:
            imgs: Input tensor.
            attn_bias: Relative position bias tensor.

        Returns:
            Output tensor.
        """
        imgs = rearrange(
            imgs,
            'b c (h i) (w j) -> b (h w) (c i j)',
            i=self.patch_size,
            j=self.patch_size,
        )
        # imgs is shape -> (bsz, num_patches, self.in_channels*self.patch_size*self.patch_size)
        imgs = self.linear_input(imgs)
        imgs = self.transformer(imgs, relative_position_bias=attn_bias)
        return imgs


class CROMABase_Weights(WeightsEnum):  # type: ignore[misc]
    """CROMA base model weights.

    .. versionadded:: 0.7
    """

    CROMA_VIT = Weights(
        url='https://hf.co/torchgeo/croma/resolve/387883f08af79d777167519c57cd826eda89a16f/CROMA_base-0238d814.pt',
        transforms=None,
        meta={
            'dataset': 'SSL4EO',
            'model': 'vit',
            'publication': 'https://arxiv.org/abs/2311.00566',
            'repo': 'https://github.com/antofuller/CROMA',
            'ssl_method': 'croma',
        },
    )


class CROMALarge_Weights(WeightsEnum):  # type: ignore[misc]
    """CROMA large model weights.

    .. versionadded:: 0.7
    """

    CROMA_VIT = Weights(
        url='https://huggingface.co/torchgeo/croma/resolve/92cb1a0f4e34c6c01558baf070197c01255382f6/CROMA_large-921e69ad.pt',
        transforms=None,
        meta={
            'dataset': 'SSL4EO',
            'model': 'vit',
            'publication': 'https://arxiv.org/abs/2311.00566',
            'repo': 'https://github.com/antofuller/CROMA',
            'ssl_method': 'croma',
        },
    )
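

# --- Example (not part of the original module) --------------------------------
# The entries above are standard torchvision WeightsEnum members, so the
# checkpoint URL and metadata can be inspected without downloading anything.
# The helper name is illustrative.
def _demo_weights_meta() -> None:
    """Inspect the CROMA base weights entry (sketch)."""
    weights = CROMABase_Weights.CROMA_VIT
    print(weights.url)
    print(weights.meta['publication'])  # https://arxiv.org/abs/2311.00566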


def load_weights(model: CROMA, weights: WeightsEnum) -> None:
    """Load weights from a WeightsEnum object.

    Args:
        model: Model to load the weights into.
        weights: Weights to load.

    Raises:
        AssertionError: If there are missing or unexpected keys.
    """
    state_dict = weights.get_state_dict(progress=True)
    missing_keys, unexpected_keys = [], []
    if 'sar' in model.modalities:
        miss_key, unexp_key = model.s1_encoder.load_state_dict(
            state_dict['s1_encoder'], strict=False
        )
        missing_keys.extend(miss_key)
        unexpected_keys.extend(unexp_key)
        miss_key, unexp_key = model.s1_GAP_FFN.load_state_dict(
            state_dict['s1_GAP_FFN'], strict=False
        )
        missing_keys.extend(miss_key)
        unexpected_keys.extend(unexp_key)
    if 'optical' in model.modalities:
        miss_key, unexp_key = model.s2_encoder.load_state_dict(
            state_dict['s2_encoder'], strict=False
        )
        missing_keys.extend(miss_key)
        unexpected_keys.extend(unexp_key)
        miss_key, unexp_key = model.s2_GAP_FFN.load_state_dict(
            state_dict['s2_GAP_FFN'], strict=False
        )
        missing_keys.extend(miss_key)
        unexpected_keys.extend(unexp_key)
    if set(model.modalities) == {'sar', 'optical'}:
        miss_key, unexp_key = model.joint_encoder.load_state_dict(
            state_dict['joint_encoder'], strict=False
        )
        missing_keys.extend(miss_key)
        unexpected_keys.extend(unexp_key)

    assert not missing_keys, f'Missing keys: {missing_keys}'
    assert not unexpected_keys, f'Unexpected keys: {unexpected_keys}'
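

# --- Example (not part of the original module) --------------------------------
# Sketch of the lower-level loading path used by the factory functions below:
# build a CROMA model with the base hyperparameters, then call load_weights()
# directly. Downloading the checkpoint requires network access; the helper name
# is illustrative.
def _demo_load_weights() -> None:
    """Load the base checkpoint into a manually constructed CROMA model (sketch)."""
    model = CROMA(encoder_dim=768, encoder_depth=12, num_heads=16, patch_size=8)
    load_weights(model, CROMABase_Weights.CROMA_VIT)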


def croma_base(
    weights: CROMABase_Weights | None = None, *args: Any, **kwargs: Any
) -> CROMA:
    """CROMA base model.

    If you use this model in your research, please cite the following paper:

    * https://arxiv.org/abs/2311.00566

    .. versionadded:: 0.7

    Args:
        weights: Pretrained weights to load.
        *args: Additional arguments to pass to :class:`CROMA`.
        **kwargs: Additional keyword arguments to pass to :class:`CROMA`.

    Returns:
        CROMA base model.
    """
    kwargs |= {
        'encoder_dim': 768,
        'encoder_depth': 12,
        'num_heads': 16,
        'patch_size': 8,
    }
    model = CROMA(*args, **kwargs)
    if weights:
        load_weights(model, weights)
    return model
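

# --- Example (not part of the original module) --------------------------------
# Sketch of building the pretrained base model for a single modality and
# extracting pooled optical features. Downloading the checkpoint requires
# network access; image_size stays at the pretraining default of 120 and the
# helper name is illustrative.
def _demo_croma_base_pretrained() -> None:
    """Embed an optical image with a pretrained croma_base model (sketch)."""
    model = croma_base(weights=CROMABase_Weights.CROMA_VIT, modalities=['optical'])
    x_optical = torch.randn(1, 12, 120, 120)
    with torch.no_grad():
        out = model(x_optical=x_optical)
    print(out['optical_GAP'].shape)  # [1, 768]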


def croma_large(
    weights: CROMALarge_Weights | None = None, *args: Any, **kwargs: Any
) -> CROMA:
    """CROMA large model.

    If you use this model in your research, please cite the following paper:

    * https://arxiv.org/abs/2311.00566

    .. versionadded:: 0.7

    Args:
        weights: Pretrained weights to load.
        *args: Additional arguments to pass to :class:`CROMA`.
        **kwargs: Additional keyword arguments to pass to :class:`CROMA`.

    Returns:
        CROMA large model.
    """
    kwargs |= {
        'encoder_dim': 1024,
        'encoder_depth': 24,
        'num_heads': 16,
        'patch_size': 8,
    }
    model = CROMA(*args, **kwargs)
    if weights:
        load_weights(model, weights)
    return model
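

# --- Example (not part of the original module) --------------------------------
# Sketch of a SAR-only croma_large model without pretrained weights, showing
# that a single modality can be requested and only that branch is constructed.
# The helper name is illustrative.
def _demo_croma_large_sar_only() -> None:
    """Run the large SAR encoder on a random backscatter mini-batch (sketch)."""
    model = croma_large(modalities=['sar'])
    x_sar = torch.randn(1, 2, 120, 120)
    out = model(x_sar=x_sar)
    print(out['sar_encodings'].shape, out['sar_GAP'].shape)  # [1, 225, 1024], [1, 1024]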
