
Source code for torchgeo.models.ltae

# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.

# Copyright (c) 2020 VSainteuf (Vivien Sainte Fare Garnot)

"""Lightweight Temporal Attention Encoder (L-TAE) model."""

import math
from collections.abc import Sequence

import torch
import torch.nn as nn
from torch import Tensor


class LTAE(nn.Module):
    """Lightweight Temporal Attention Encoder (L-TAE).

    This model implements a lightweight temporal attention encoder that processes
    time series data using a multi-head attention mechanism. It is designed to
    efficiently encode temporal sequences into fixed-length embeddings.

    If you use this model in your research, please cite the following paper:

    * https://arxiv.org/abs/2007.00586

    .. versionadded:: 0.8
    """
    def __init__(
        self,
        in_channels: int = 128,
        n_head: int = 16,
        d_k: int = 8,
        n_neurons: Sequence[int] = (256, 128),
        dropout: float = 0.2,
        d_model: int | None = 256,
        T: int = 1000,
        len_max_seq: int = 24,
        positions: Sequence[int] | None = None,
    ) -> None:
        """Sequence-to-embedding encoder.

        Args:
            in_channels: Number of channels of the input embeddings
            n_head: Number of attention heads
            d_k: Dimension of the key and query vectors
            n_neurons: Defines the dimensions of the successive feature spaces of the
                MLP that processes the concatenated outputs of the attention heads
            dropout: Dropout rate
            T: Period to use for the positional encoding
            len_max_seq: Maximum sequence length, used to pre-compute the positional
                encoding table
            positions: List of temporal positions to use instead of position in the
                sequence
            d_model: If specified, the input tensors will first be processed by a
                fully connected layer to project them into a feature space of
                dimension d_model
        """
        super().__init__()
        self.in_channels = in_channels
        self.positions = positions
        self.n_neurons = n_neurons
        self.d_model = d_model if d_model is not None else in_channels

        self.inconv: nn.Sequential | None = None
        if d_model is not None:
            self.inconv = nn.Sequential(
                nn.Conv1d(in_channels, d_model, 1),
                nn.LayerNorm([d_model, len_max_seq]),
            )

        # Use PyTorch's built-in positional encoding
        self.pos_encoder = PositionalEncoding(self.d_model, dropout, T)

        # Use PyTorch's built-in MultiheadAttention
        self.attention = nn.MultiheadAttention(
            embed_dim=self.d_model, num_heads=n_head, dropout=dropout, batch_first=True
        )

        self.inlayernorm = nn.LayerNorm(self.in_channels)
        self.outlayernorm = nn.LayerNorm(n_neurons[-1])

        assert self.n_neurons[0] == self.d_model
        activation = nn.ReLU(inplace=True)
        layers = []
        for i in range(len(self.n_neurons) - 1):
            layers.extend(
                [
                    nn.Linear(self.n_neurons[i], self.n_neurons[i + 1]),
                    nn.BatchNorm1d(self.n_neurons[i + 1]),
                    activation,
                ]
            )
        self.mlp = nn.Sequential(*layers)
        self.dropout = nn.Dropout(dropout, inplace=True)
    def forward(self, x: Tensor) -> Tensor:
        """Forward pass of the model.

        Args:
            x: Input tensor of shape (batch_size, seq_len, in_channels)

        Returns:
            Output tensor of shape (batch_size, n_neurons[-1])
        """
        sz_b, seq_len, d = x.shape

        x = self.inlayernorm(x)

        if self.inconv is not None:
            x = self.inconv(x.permute(0, 2, 1)).permute(0, 2, 1)

        # Apply positional encoding
        x = self.pos_encoder(x)

        # Apply multi-head attention
        # PyTorch's MultiheadAttention expects query, key, value
        attention_output, _ = self.attention(x, x, x)

        # Process through MLP
        # Take the mean over the sequence dimension to get a fixed-size representation
        mlp_input = attention_output.mean(dim=1)  # (batch_size, d_model)

        output: Tensor = self.outlayernorm(self.dropout(self.mlp(mlp_input)))

        return output
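

# ---------------------------------------------------------------------------
# A minimal usage sketch of LTAE with its default hyperparameters, showing the
# expected (batch_size, seq_len, in_channels) input and the pooled
# (batch_size, n_neurons[-1]) output. The batch size of 4, eval()/no_grad(),
# and the helper name `_demo_ltae_forward` are illustrative assumptions, not
# part of the library API. The sketch is wrapped in a function because
# PositionalEncoding is defined further down in this file, so LTAE can only be
# instantiated once the whole module has been loaded.
def _demo_ltae_forward() -> Tensor:
    """Run LTAE on a random clip and return the pooled embedding."""
    model = LTAE()  # in_channels=128, len_max_seq=24, n_neurons=(256, 128)
    model.eval()  # disable dropout; BatchNorm1d uses its running statistics
    # seq_len must equal len_max_seq because of the LayerNorm([d_model, len_max_seq])
    x = torch.rand(4, 24, 128)  # (batch_size, seq_len, in_channels)
    with torch.no_grad():
        out = model(x)
    assert out.shape == (4, 128)  # (batch_size, n_neurons[-1])
    return out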


class PositionalEncoding(nn.Module):
    """Positional encoding module using sinusoidal functions."""

    def __init__(self, d_model: int, dropout: float = 0.1, T: int = 1000) -> None:
        """Initialize the positional encoding.

        Args:
            d_model: The dimension of the embeddings
            dropout: Dropout rate
            T: Period for the sinusoidal functions
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(T).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(T) / d_model))
        pe = torch.zeros(1, T, d_model)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass.

        Args:
            x: Input tensor of shape (batch_size, seq_len, d_model)

        Returns:
            Output tensor with positional encoding added
        """
        # Get positional encoding up to the sequence length
        pe = self.pe[:, : x.size(1)]  # type: ignore[index]
        output: Tensor = self.dropout(x + pe)
        return output
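

# ---------------------------------------------------------------------------
# A minimal sketch of PositionalEncoding on its own, assuming a feature size
# of 256 and dropout=0.0 (chosen only to keep the output deterministic): the
# sinusoidal table is precomputed with shape (1, T, d_model) and sliced to the
# input's sequence length, so encoding a (2, 24, 256) batch preserves its shape.
if __name__ == '__main__':
    enc = PositionalEncoding(d_model=256, dropout=0.0, T=1000)
    feats = torch.zeros(2, 24, 256)  # (batch_size, seq_len, d_model)
    encoded = enc(feats)
    assert enc.pe.shape == (1, 1000, 256)  # precomputed sinusoidal table
    assert encoded.shape == feats.shape  # positional encoding is shape-preserving
    print(encoded.shape, _demo_ltae_forward().shape)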
