Source code for torchgeo.models.copernicusfm

# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.

# https://github.com/zhu-xlab/Copernicus-FM

"""Copernicus Foundation Model (Copernicus-FM)."""

import math
from collections.abc import Sequence
from typing import Any, Literal

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from einops import rearrange
from timm.models.vision_transformer import Block
from torch import Tensor, vmap
from torchvision.models._api import Weights, WeightsEnum

from ..samplers.utils import _to_tuple
from .dofa import FCResLayer, TransformerWeightGenerator


def resize_abs_pos_embed(
    pos_embed: Tensor,
    new_size: int | tuple[int, int],
    old_size: int | tuple[int, int],
    num_prefix_tokens: int = 1,
    interpolation: str = 'bicubic',
    antialias: bool = True,
) -> Tensor:
    """Resize absolute position embeddings to a target resolution via interpolation.

    Adapted from https://github.com/bwconrad/flexivit. Copyright (c) 2023 Ben Conrad.

    Args:
        pos_embed: Position embeddings tensor of size [b, n, d]
        new_size: Target [height, width] of embedding
        old_size: Original [height, width] of embedding
        num_prefix_tokens: Number of non-spatial prefix tokens (e.g., cls)
        interpolation: Resize interpolation type
        antialias: Whether to apply antialiasing resizing

    Returns:
        Resized pos_embed of size [b, n', d]
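
    Example:
        A minimal usage sketch (zeros used only to illustrate shapes):

        >>> pe = torch.zeros(1, 1 + 14 * 14, 768)  # 14x14 grid plus a cls token
        >>> resize_abs_pos_embed(pe, new_size=7, old_size=14).shape
        torch.Size([1, 50, 768])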
    """
    new_size = _to_tuple(new_size)
    old_size = _to_tuple(old_size)
    new_ntok = new_size[0] * new_size[1]

    # Return if no resize necessary
    if new_size == old_size:
        return pos_embed

    if num_prefix_tokens:
        posemb_prefix, pos_embed = (
            pos_embed[:, :num_prefix_tokens],
            pos_embed[:, num_prefix_tokens:],
        )
    else:
        posemb_prefix, pos_embed = None, pos_embed

    # Interpolate position embedding
    pos_embed = pos_embed.reshape(1, old_size[0], old_size[1], -1).permute(0, 3, 1, 2)
    pos_embed = F.interpolate(
        pos_embed, size=new_size, mode=interpolation, antialias=antialias
    )
    pos_embed = pos_embed.permute(0, 2, 3, 1).reshape(1, new_ntok, -1)

    # Add back extra prefix tokens
    if posemb_prefix is not None:
        pos_embed = torch.cat([posemb_prefix, pos_embed], dim=1)

    return pos_embed


def pi_resize_patch_embed(
    patch_embed: Tensor,
    new_patch_size: tuple[int, int],
    interpolation: str = 'bicubic',
    antialias: bool = True,
) -> Tensor:
    """Resample patch embeddings to a target resolution via pseudo-inverse resizing.

    Adapted from https://github.com/bwconrad/flexivit. Copyright (c) 2023 Ben Conrad.

    Args:
        patch_embed: Patch embedding parameters of size [d, c, h, w]
        new_patch_size: Target [height, width] of embedding
        interpolation: Resize interpolation type
        antialias: Whether to apply antialiasing resizing

    Returns:
        Resized patch_embed of size [d, c, h', w']
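
    Example:
        A minimal usage sketch (small random kernel, for illustration only):

        >>> kernel = torch.randn(8, 3, 16, 16)  # [d, c, h, w]
        >>> pi_resize_patch_embed(kernel, (8, 8)).shape
        torch.Size([8, 3, 8, 8])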
    """
    assert len(patch_embed.shape) == 4, 'Patch embed kernel should be a 4D tensor'
    assert len(new_patch_size) == 2, 'New patch size should only be (height, width)'

    _, _, h, w = patch_embed.shape
    old_patch_size = (h, w)

    # Return original kernel if no resize is necessary
    if old_patch_size == new_patch_size:
        return patch_embed

    def resize(x: Tensor, shape: tuple[int, int]) -> Tensor:
        x = F.interpolate(
            x[None, None, ...], shape, mode=interpolation, antialias=antialias
        )
        return x[0, 0, ...]

    def calculate_pinv(
        old_shape: tuple[int, int], new_shape: tuple[int, int]
    ) -> Tensor:
        mat = []
        for i in range(np.prod(old_shape)):
            basis_vec = torch.zeros(old_shape)
            basis_vec[np.unravel_index(i, old_shape)] = 1.0
            mat.append(resize(basis_vec, new_shape).reshape(-1))
        resize_matrix = torch.stack(mat)
        pinv: Tensor = torch.linalg.pinv(resize_matrix)
        return pinv

    # Calculate pseudo-inverse of resize matrix
    resize_matrix_pinv = calculate_pinv(old_patch_size, new_patch_size)
    resize_matrix_pinv = resize_matrix_pinv.to(patch_embed.device)

    def resample_patch_embed(patch_embed: Tensor) -> Tensor:
        h, w = new_patch_size
        resampled_kernel = resize_matrix_pinv @ patch_embed.reshape(-1)
        return rearrange(resampled_kernel, '(h w) -> h w', h=h, w=w)

    v_resample_patch_embed = vmap(vmap(resample_patch_embed, 0, 0), 1, 1)

    patch_embed = v_resample_patch_embed(patch_embed)
    return patch_embed


class FourierExpansion(nn.Module):
    """A Fourier series-style expansion into a high-dimensional space.

    Adapted from https://github.com/microsoft/aurora.
    Copyright (c) Microsoft Corporation.
    """

    def __init__(self, lower: float, upper: float, assert_range: bool = True) -> None:
        """Initialise.

        Args:
            lower: Lower wavelength.
            upper: Upper wavelength.
            assert_range: Assert that the encoded tensor is within the specified
                wavelength range.
        """
        super().__init__()
        self.lower = lower
        self.upper = upper
        self.assert_range = assert_range

    def forward(self, x: Tensor, d: int) -> Tensor:
        """Perform the expansion.

        Adds a dimension of length `d` to the end of the shape of `x`.

        Args:
            x: Input to expand of shape `(..., n)`. All elements of `x` must
               lie within `[self.lower, self.upper]` if `self.assert_range` is `True`.
            d: Dimensionality. Must be a multiple of two.

        Raises:
            AssertionError: If `self.assert_range` is `True` and not all elements
                of `x` are within `[self.lower, self.upper]` or zero.
            ValueError: If `d` is not a multiple of two.

        Returns:
            Fourier series-style expansion of `x` of shape `(..., n, d)`.
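
        Example:
            A minimal usage sketch (wavelength range 1-1000 chosen arbitrarily):

            >>> expansion = FourierExpansion(1.0, 1000.0)
            >>> expansion(torch.tensor([2.0, 300.0]), 8).shape
            torch.Size([2, 8])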
        """
        # If the input is not within the configured range, the embedding might be ambiguous!
        in_range = torch.logical_and(
            self.lower <= x.abs(), torch.all(x.abs() <= self.upper)
        )
        # Allow zeros to pass through.
        in_range_or_zero = torch.all(torch.logical_or(in_range, x == 0))
        if self.assert_range and not in_range_or_zero:
            raise AssertionError(
                f'The input tensor is not within the configured range'
                f' `[{self.lower}, {self.upper}]`.'
            )

        # We will use half of the dimensionality for `sin` and the other half for `cos`.
        if not (d % 2 == 0):
            raise ValueError('The dimensionality must be a multiple of two.')

        # Always perform the expansion with `float64`s to avoid numerical accuracy shenanigans.
        x = x.double()

        wavelengths = torch.logspace(
            math.log10(self.lower),
            math.log10(self.upper),
            d // 2,
            base=10,
            device=x.device,
            dtype=x.dtype,
        )
        prod = torch.einsum('...i,j->...ij', x, 2 * np.pi / wavelengths)
        encoding = torch.cat((torch.sin(prod), torch.cos(prod)), dim=-1)

        return encoding.float()  # Cast to `float32` to avoid incompatibilities.


class DynamicPatchEmbed(nn.Module):
    """Dynamic patch embedding with spectral or variable hypernetworks.

    Adapted from DOFA.
    """

    def __init__(
        self,
        hyper_dim: int = 128,
        kernel_size: int = 16,
        embed_dim: int = 1024,
        input_mode: Literal['spectral', 'variable'] = 'spectral',
    ) -> None:
        """Initialize a new DynamicPatchEmbed instance.

        Args:
            hyper_dim: Dimension of the wavelength/bandwidth/variable-name encoding.
            kernel_size: Kernel size for the patch embedding (convolution) layer.
            embed_dim: Embedding dimension.
            input_mode: Type of hypernetwork to use. Options: 'spectral' or 'variable'.
                'spectral' uses Fourier encodings for wavelength and bandwidth;
                'variable' uses a language embedding for variable names.
        """
        super().__init__()
        self.input_mode = input_mode
        self.kernel_size = kernel_size
        self.hyper_dim = hyper_dim
        self.embed_dim = embed_dim
        self._num_kernel = self.kernel_size * self.kernel_size * self.embed_dim
        self.patch_size = (kernel_size, kernel_size)
        self.num_patches = -1

        if self.input_mode == 'spectral':
            # Spectral hypernetwork: Fourier encoding for wavelength and bandwidth.
            # min wavelength: ultraviolet light (100 nm)
            # max wavelength: radio waves (1 m)
            # min bandwidth: s2 ~ 10 nm
            # max bandwidth: s1 ~ 1 m
            self.spectrum_central_expansion = FourierExpansion(100, 1e9)
            self.spectrum_bandwidth_expansion = FourierExpansion(1, 1e9)
        elif self.input_mode == 'variable':
            # Variable hypernetwork: Language embedding for variable names.
            self.language_proj = nn.Linear(2048, self.hyper_dim)

        self.weight_generator = TransformerWeightGenerator(
            hyper_dim, self._num_kernel, embed_dim
        )
        self.scaler = 0.01
        self.fclayer = FCResLayer(hyper_dim)
        self._init_weights()

    def _get_weights(self, waves: Tensor) -> Tensor:
        """Use the dynamic weight generator.

        Args:
            waves: Spectral wavelengths.

        Returns:
            Dynamic weights.
        """
        dynamic_weights: Tensor = self.weight_generator(waves)
        return dynamic_weights

    def weight_init(self, m: object) -> None:
        """Initialize weights of a single layer.

        Args:
            m: A single layer.
        """
        if isinstance(m, nn.Linear):
            init.xavier_uniform_(m.weight)
            if m.bias is not None:
                m.bias.data.fill_(0.01)

    def _init_weights(self) -> None:
        """Initialize weights of all layers."""
        self.weight_generator.apply(self.weight_init)
        self.fclayer.apply(self.weight_init)

    def forward(
        self,
        x: Tensor,
        wavelengths: Tensor | None = None,
        bandwidths: Tensor | None = None,
        language_embed: Tensor | None = None,
        kernel_size: int | None = None,
    ) -> Tensor:
        """Forward pass.

        For input_mode=='spectral', `wavelengths` and `bandwidths` must be provided.
        For input_mode=='variable', `language_embed` must be provided.

        Args:
            x: Input image tensor (B, C, H, W).
            wavelengths: Wavelengths in nm (required if input_mode=='spectral').
            bandwidths: Bandwidths in nm (required if input_mode=='spectral').
            language_embed: Language embedding tensor from Llama 3.2 1B (length 2048).
            kernel_size: If provided and differs from the initialized kernel size,
                the generated patch embed kernel weights are resized accordingly.

        Returns:
            Output after patch embedding (B, N, D).

        Raises:
            ValueError: When *input_mode=='spectral'* and *wavelengths* or *bandwidths* is missing,
                or when *input_mode=='variable'* and *language_embed* is missing.
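
        Example:
            A minimal usage sketch of the spectral mode, using Sentinel-2 B/G/R/NIR
            center wavelengths and bandwidths:

            >>> embed = DynamicPatchEmbed(kernel_size=16, embed_dim=1024)
            >>> x = torch.randn(1, 4, 224, 224)
            >>> wavelengths = torch.tensor([490.0, 560.0, 665.0, 842.0])
            >>> bandwidths = torch.tensor([65.0, 35.0, 30.0, 115.0])
            >>> embed(x, wavelengths=wavelengths, bandwidths=bandwidths).shape
            torch.Size([1, 196, 1024])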
        """
        if self.input_mode == 'spectral':
            if wavelengths is None or bandwidths is None:
                msg = 'For spectral hypernet, wavelengths and bandwidths must be provided.'
                raise ValueError(msg)

            emb_central = self.spectrum_central_expansion(wavelengths, self.hyper_dim)
            emb_bandwidth = self.spectrum_bandwidth_expansion(
                bandwidths, self.hyper_dim
            )
            waves = emb_central + emb_bandwidth
        elif self.input_mode == 'variable':
            if language_embed is None:
                msg = 'For variable hypernet, language_embed must be provided.'
                raise ValueError(msg)

            # Expand dims to match batch size.
            waves = self.language_proj(language_embed.unsqueeze(0))

        waves = self.fclayer(waves)
        weight, bias = self._get_weights(waves)
        inplanes = waves.size(0)
        dynamic_weight = weight.view(
            inplanes, self.kernel_size, self.kernel_size, self.embed_dim
        )
        dynamic_weight = dynamic_weight.permute(3, 0, 1, 2)

        if kernel_size is not None and self.kernel_size != kernel_size:
            dynamic_weight = pi_resize_patch_embed(
                dynamic_weight, (kernel_size, kernel_size)
            )
        else:
            kernel_size = self.kernel_size

        if bias is not None:
            bias = bias.view(self.embed_dim) * self.scaler

        weights = dynamic_weight * self.scaler

        dynamic_out = F.conv2d(
            x, weights, bias=bias, stride=kernel_size, padding=1, dilation=1
        )
        x = dynamic_out.flatten(2).transpose(1, 2)
        return x


class CopernicusFM(nn.Module):
    """CopernicusFM: VisionTransformer backbone.

    Example:
        **1. Spectral Mode (Using Wavelength and Bandwidth):**

        >>> model = CopernicusFM()
        >>> x = torch.randn(1, 4, 224, 224)  # input image
        >>> # metadata: [lon (degree), lat (degree), delta_time (days since 1970/1/1),
        >>> #            patch_token_area (km^2)], assume unknown
        >>> metadata = torch.full((1, 4), float('nan'))
        >>> wavelengths = [490, 560, 665, 842]  # wavelength (nm): B, G, R, NIR (Sentinel-2)
        >>> bandwidths = [65, 35, 30, 115]  # bandwidth (nm): B, G, R, NIR (Sentinel-2)
        >>> kernel_size = 16  # expected patch size
        >>> input_mode = 'spectral'
        >>> logit = model(
        ...     x,
        ...     metadata,
        ...     wavelengths=wavelengths,
        ...     bandwidths=bandwidths,
        ...     input_mode=input_mode,
        ...     kernel_size=kernel_size,
        ... )
        >>> print(logit.shape)

        **2. Variable Mode (Using language embedding):**

        >>> model = CopernicusFM()
        >>> # variable name (as input to an LLM for the language embedding)
        >>> varname = 'Sentinel 5P Nitrogen Dioxide'
        >>> x = torch.randn(1, 1, 56, 56)  # input image
        >>> metadata = torch.full((1, 4), float('nan'))  # lon/lat/time/area, assume unknown
        >>> # language embedding: encode varname with an LLM (e.g. Llama)
        >>> language_embed = torch.randn(2048)
        >>> kernel_size = 4  # expected patch size
        >>> input_mode = 'variable'
        >>> logit = model(
        ...     x,
        ...     metadata,
        ...     language_embed=language_embed,
        ...     input_mode=input_mode,
        ...     kernel_size=kernel_size,
        ... )
        >>> print(logit.shape)
    """

    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 16,
        drop_rate: float = 0.0,
        embed_dim: int = 1024,
        depth: int = 24,
        num_heads: int = 16,
        hyper_dim: int = 128,
        num_classes: int = 0,
        global_pool: bool = True,
        mlp_ratio: float = 4.0,
        norm_layer: type[nn.Module] = nn.LayerNorm,
    ) -> None:
        """Initialize a new CopernicusFM instance.

        Args:
            img_size: Input image size.
            patch_size: Patch size.
            drop_rate: Head dropout rate.
            embed_dim: Transformer embedding dimension.
            depth: Depth of transformer.
            num_heads: Number of attention heads.
            hyper_dim: Dimensions of dynamic weight generator.
            num_classes: Number of classes for classification head.
            global_pool: Whether or not to perform global pooling.
            mlp_ratio: Ratio of MLP hidden dim to embedding dim.
            norm_layer: Normalization layer.
        """
        super().__init__()
        self.hyper_dim = hyper_dim
        self.global_pool = global_pool
        if self.global_pool:
            self.fc_norm = norm_layer(embed_dim)
        else:
            self.norm = norm_layer(embed_dim)

        self.patch_embed_spectral = DynamicPatchEmbed(
            hyper_dim=hyper_dim,
            kernel_size=patch_size,
            embed_dim=embed_dim,
            input_mode='spectral',
        )
        self.patch_embed_variable = DynamicPatchEmbed(
            hyper_dim=hyper_dim,
            kernel_size=patch_size,
            embed_dim=embed_dim,
            input_mode='variable',
        )
        self.num_patches = (img_size // patch_size) ** 2
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Fixed sin-cos embedding
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.num_patches + 1, embed_dim), requires_grad=False
        )

        self.coord_expansion = FourierExpansion(0.0001, 720)
        self.scale_expansion = FourierExpansion(0.001, 5.1e8)  # 1 m2 to 5.1e8 km2
        # 1 to 365.25 days, enable more than 1 year
        self.time_expansion = FourierExpansion(1, 365.25, assert_range=False)
        self.coord_fc = nn.Linear(embed_dim, embed_dim)
        self.scale_fc = nn.Linear(embed_dim, embed_dim)
        self.time_fc = nn.Linear(embed_dim, embed_dim)
        # if metadata is not available, set to a learned parameter
        self.coord_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.scale_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.time_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

        self.blocks = nn.ModuleList(
            [
                Block(
                    embed_dim,
                    num_heads,
                    mlp_ratio,
                    qkv_bias=True,
                    norm_layer=norm_layer,
                )
                for i in range(depth)
            ]
        )
        self.head_drop = nn.Dropout(drop_rate)
        self.head = (
            nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
        )

    def get_coord_pos_embed(
        self, lons: Tensor, lats: Tensor, embed_dim: int
    ) -> Tensor:
        """Geospatial coordinate position embedding.

        Args:
            lons: Longitudes (x).
            lats: Latitudes (y).
            embed_dim: Embedding dimension.

        Returns:
            Coordinate position embedding.
        """
        coord_embed_lon = self.coord_expansion(lons + 180, embed_dim // 2)
        coord_embed_lat = self.coord_expansion(lats + 90, embed_dim // 2)
        coord_embed = torch.cat([coord_embed_lon, coord_embed_lat], dim=-1)
        if coord_embed.shape[-1] < embed_dim:
            # pad zeros
            coord_embed = torch.cat(
                (
                    coord_embed,
                    torch.zeros(
                        coord_embed.shape[0],
                        embed_dim - coord_embed.shape[-1],
                        device=coord_embed.device,
                    ),
                ),
                dim=-1,
            )
        return coord_embed.unsqueeze(1)  # [B, 1, D]

    def get_area_pos_embed(self, areas: Tensor, embed_dim: int) -> Tensor:
        """Geospatial area position embedding.

        Args:
            areas: Spatial areas.
            embed_dim: Embedding dimension.

        Returns:
            Area position embedding.
        """
        scale_embed: Tensor = self.scale_expansion(areas, embed_dim)  # [B, D]
        scale_embed = scale_embed.unsqueeze(1)  # [B, 1, D]
        return scale_embed

    def get_time_pos_embed(self, times: Tensor, embed_dim: int) -> Tensor:
        """Geotemporal position embedding.

        Args:
            times: Timestamps.
            embed_dim: Embedding dimension.

        Returns:
            Temporal position embedding.
        """
        time_embed: Tensor = self.time_expansion(times, embed_dim)  # [B, D]
        time_embed = time_embed.unsqueeze(1)  # [B, 1, D]
        return time_embed

    def forward_features(
        self,
        x: Tensor,
        metadata: Tensor,
        wavelengths: Sequence[float] | None = None,
        bandwidths: Sequence[float] | None = None,
        language_embed: Tensor | None = None,
        input_mode: Literal['spectral', 'variable'] = 'spectral',
        kernel_size: int | None = None,
    ) -> Tensor:
        """Forward pass of the feature embedding layer.

        Args:
            x: Input mini-batch.
            metadata: Longitudes (degree), latitudes (degree), times (days since
                1970/1/1), and areas (km^2) of each patch. Use NaN for unknown
                metadata.
            wavelengths: Wavelengths of each spectral band (nm).
                Only used if *input_mode=='spectral'*.
            bandwidths: Bandwidths in nm. Only used if *input_mode=='spectral'*.
            language_embed: Language embedding tensor from Llama 3.2 1B (length 2048).
                Only used if *input_mode=='variable'*.
            input_mode: One of 'spectral' or 'variable'.
            kernel_size: If provided and differs from the initialized kernel size,
                the generated patch embed kernel weights are resized accordingly.

        Returns:
            Output mini-batch.
        """
        if input_mode == 'spectral':
            wvs = torch.tensor(wavelengths, device=x.device).float()
            bws = torch.tensor(bandwidths, device=x.device).float()
            x = self.patch_embed_spectral(
                x, wavelengths=wvs, bandwidths=bws, kernel_size=kernel_size
            )
        elif input_mode == 'variable':
            x = self.patch_embed_variable(
                x, language_embed=language_embed, kernel_size=kernel_size
            )

        # resize pos embed
        num_patches = x.size(1)
        num_patches_sqrt = int(math.sqrt(num_patches))
        num_patches_sqrt_origin = int(math.sqrt(self.num_patches))
        pos_embed = resize_abs_pos_embed(
            self.pos_embed,
            num_patches_sqrt,
            (num_patches_sqrt_origin, num_patches_sqrt_origin),
            num_prefix_tokens=1,
        )

        # coord, scale and time pos embed
        lons, lats, times, areas = (
            metadata[:, 0],
            metadata[:, 1],
            metadata[:, 2],
            metadata[:, 3],
        )
        embed_dim = pos_embed.shape[-1]

        if torch.isnan(lons).any() or torch.isnan(lats).any():
            coord_embed: nn.Parameter | Tensor = self.coord_token
        else:
            coord_embed = self.get_coord_pos_embed(lons, lats, embed_dim)
            coord_embed = self.coord_fc(coord_embed)

        if torch.isnan(areas).any():
            area_embed: nn.Parameter | Tensor = self.scale_token
        else:
            area_embed = self.get_area_pos_embed(areas, embed_dim)
            area_embed = self.scale_fc(area_embed)

        if torch.isnan(times).any():
            time_embed: nn.Parameter | Tensor = self.time_token
        else:
            time_embed = self.get_time_pos_embed(times, embed_dim)
            time_embed = self.time_fc(time_embed)

        pos_embed = pos_embed + coord_embed + area_embed + time_embed

        # add pos embed w/o cls token
        x = x + pos_embed[:, 1:, :]

        # append cls token
        cls_token = self.cls_token + pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for block in self.blocks:
            x = block(x)

        if self.global_pool:
            x = x[:, 1:, :].mean(dim=1)  # global pool without cls token
            outcome: Tensor = self.fc_norm(x)
        else:
            x = self.norm(x)
            outcome = x[:, 0]
        return outcome

    def forward_head(self, x: Tensor, pre_logits: bool = False) -> Tensor:
        """Forward pass of the attention head.

        Args:
            x: Input mini-batch.
            pre_logits: Whether or not to return the layer before logits are computed.

        Returns:
            Output mini-batch.
        """
        x = self.head_drop(x)
        return x if pre_logits else self.head(x)

    def forward(
        self,
        x: Tensor,
        metadata: Tensor,
        wavelengths: Sequence[float] | None = None,
        bandwidths: Sequence[float] | None = None,
        language_embed: Tensor | None = None,
        input_mode: Literal['spectral', 'variable'] = 'spectral',
        kernel_size: int | None = None,
    ) -> Tensor:
        """Forward pass of the model.

        Args:
            x: Input mini-batch.
            metadata: Longitudes (degree), latitudes (degree), times (days since
                1970/1/1), and areas (km^2) of each patch. Use NaN for unknown
                metadata.
            wavelengths: Wavelengths of each spectral band (nm).
                Only used if *input_mode=='spectral'*.
            bandwidths: Bandwidths in nm. Only used if *input_mode=='spectral'*.
            language_embed: Language embedding tensor from Llama 3.2 1B (length 2048).
                Only used if *input_mode=='variable'*.
            input_mode: One of 'spectral' or 'variable'.
            kernel_size: If provided and differs from the initialized kernel size,
                the generated patch embed kernel weights are resized accordingly.

        Returns:
            Output mini-batch.
        """
        fx = self.forward_features(
            x,
            metadata,
            wavelengths,
            bandwidths,
            language_embed,
            input_mode,
            kernel_size,
        )
        x = self.forward_head(fx)
        return x


class CopernicusFM_Base_Weights(WeightsEnum):  # type: ignore[misc]
    """Copernicus-FM base weights."""

    CopernicusFM_ViT = Weights(
        url='https://huggingface.co/torchgeo/copernicus-fm/resolve/f395812cc990ba25a451dbb9c9e6d95c8482947e/CopernicusFM_ViT_base_varlang-085350e4.pth',
        transforms=None,
        meta={
            'dataset': 'Copernicus-Pretrain',
            'model': 'copernicusfm_base',
            'publication': 'https://arxiv.org/abs/2503.11849',
            'repo': 'https://github.com/zhu-xlab/Copernicus-FM',
            'ssl_method': 'mae+distill',
        },
    )


def copernicusfm_base(
    weights: CopernicusFM_Base_Weights | None = None, *args: Any, **kwargs: Any
) -> CopernicusFM:
    """CopernicusFM vit-base model.

    If you use this model in your research, please cite the following paper:

    * https://arxiv.org/abs/2503.11849

    .. versionadded:: 0.7

    Args:
        weights: Pre-trained model weights to use.
        *args: Additional arguments to pass to :class:`CopernicusFM`.
        **kwargs: Additional keyword arguments to pass to :class:`CopernicusFM`.

    Returns:
        A CopernicusFM base model.
    """
    kwargs |= {'embed_dim': 768, 'depth': 12, 'num_heads': 12}
    model = CopernicusFM(*args, **kwargs)

    if weights:
        missing_keys, unexpected_keys = model.load_state_dict(
            weights.get_state_dict(progress=True), strict=False
        )
        # Both fc_norm and head are generated dynamically
        assert set(missing_keys) <= {
            'fc_norm.weight',
            'fc_norm.bias',
            'head.weight',
            'head.bias',
        }
        assert not unexpected_keys

    return model
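

if __name__ == '__main__':
    # Illustrative smoke test (a minimal usage sketch; num_classes and input
    # shapes are arbitrary): build a randomly initialized base model and run a
    # Sentinel-2-style forward pass with unknown metadata. Pretrained weights
    # could instead be loaded via
    # copernicusfm_base(CopernicusFM_Base_Weights.CopernicusFM_ViT),
    # which downloads the checkpoint listed above.
    model = copernicusfm_base(num_classes=10)
    x = torch.randn(1, 4, 224, 224)
    metadata = torch.full((1, 4), float('nan'))  # lon, lat, time, area unknown
    logits = model(
        x,
        metadata,
        wavelengths=[490, 560, 665, 842],
        bandwidths=[65, 35, 30, 115],
        input_mode='spectral',
        kernel_size=16,
    )
    print(logits.shape)  # torch.Size([1, 10])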
