[ ]:
# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.

Earthquake Prediction

Written by Daniele Rege Cambrin

Introduction

The objective of this tutorial is to go through the QuakeSet dataset and cover the following topics:

  • How to use TorchGeo data modules to load datasets and plot samples;

  • How to use TorchGeo pre-trained model embeddings to train a classical model (e.g., Random Forest);

  • How to train a deep model from scratch using TorchGeo tasks and the Lightning Trainer.

Environment

For the environment, we will install the torchgeo, h5py, and scikit-learn packages.

[ ]:
%pip install torchgeo h5py scikit-learn

Imports

[ ]:
import tempfile
from collections import defaultdict
from pathlib import Path

import numpy as np
import torch
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report
from tqdm import tqdm

from torchgeo.datamodules import QuakeSetDataModule
from torchgeo.models import ResNet50_Weights, resnet50
from torchgeo.trainers import ClassificationTask

Dataset

We will use the QuakeSet dataset (licensed under the OpenRAIL license). It contains patches from around the world acquired before and after earthquakes, along with corresponding negative examples.

The dataset uses synthetic aperture radar (SAR) imagery from the Sentinel-1 mission at 10 m spatial resolution. The task is to predict, for each pair of images, whether an earthquake occurred between the two acquisitions.
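To make the sample layout concrete, here is a minimal NumPy sketch of a single QuakeSet-style sample. It follows the inference code later in this tutorial and assumes that the image stacks the two channels of the pre-event image followed by the two channels of the post-event image.

Several details are illustrative assumptions rather than the dataset's actual values:
- the 64 x 64 patch size;
- the random pixel values;
- the label encoding;
- the hypothetical helper split_pre_post.

```python
import numpy as np


def split_pre_post(image: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Split a (4, H, W) or (B, 4, H, W) array into pre- and post-event parts.

    Hypothetical helper: assumes channels 0-1 hold the pre-event image and
    channels 2-3 hold the post-event image.
    """
    channel_axis = image.ndim - 3  # 0 for a single sample, 1 for a batch
    pre, post = np.split(image, 2, axis=channel_axis)
    return pre, post


# Illustrative sample: random values stand in for SAR backscatter
rng = np.random.default_rng(0)
sample = {
    'image': rng.random((4, 64, 64), dtype=np.float32),
    'label': 1,  # assumed encoding: 1 = earthquake between the two acquisitions
}
pre, post = split_pre_post(sample['image'])
print(pre.shape, post.shape)  # (2, 64, 64) (2, 64, 64)
```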

[ ]:
num_workers = 1
batch_size = 4
max_epochs = 10
fast_dev_run = False
[ ]:
tmp_path = Path(tempfile.gettempdir())
# The data module has already been implemented in TorchGeo, so we can use it
datamodule = QuakeSetDataModule(
    batch_size=batch_size, num_workers=num_workers, download=True, root=tmp_path
)
[ ]:
# This will download the dataset
datamodule.prepare_data()

Sample Visualization

The QuakeSetDataModule already implements a plot method for visualizing samples.

Remember to call setup for the 'fit' and 'test' stages before accessing the datasets or plotting samples. The datasets are only created inside setup, so accessing them earlier raises an error.

[ ]:
datamodule.setup('fit')
datamodule.setup('test')
sample = datamodule.val_dataset[0]
datamodule.plot(sample)
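This requirement exists because data modules build their datasets lazily, inside setup, for the requested stage. The pure-Python sketch below mimics that pattern with a hypothetical ToyDataModule class. It is not TorchGeo's implementation. It only illustrates two points: indexing val_dataset before setup('fit') fails, and the test split needs its own setup('test') call.

```python
class ToyDataModule:
    """Hypothetical stand-in that mimics lazy, stage-based dataset creation."""

    def __init__(self, data):
        self.data = data
        # Nothing is built yet: these stay None until setup() runs
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def setup(self, stage):
        # 'fit' builds the training and validation splits; 'test' builds the test split
        if stage == 'fit':
            self.train_dataset = self.data['train']
            self.val_dataset = self.data['val']
        elif stage == 'test':
            self.test_dataset = self.data['test']


dm = ToyDataModule({'train': [1, 2, 3], 'val': [4], 'test': [5, 6]})
try:
    dm.val_dataset[0]  # fails: the validation split has not been created yet
except TypeError as err:
    print('Error before setup:', err)

dm.setup('fit')
dm.setup('test')
print(dm.val_dataset[0], len(dm.test_dataset))  # 4 2
```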

Train an ML model on pre-trained embeddings

The first approach uses a pre-trained deep learning model to compute image embeddings, which then serve as input features for a classical machine learning model.

First, we set constants that select the device and the fraction of the dataset to use (10% by default).

[ ]:
# Use the GPU if one is available; otherwise fall back to the CPU
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Fraction of samples to use (0.1 = 10%)
PCT_SAMPLES = 0.1
[ ]:
# We take a subset of the dataset to speed up training
datamodule.train_dataset.data = datamodule.train_dataset.data[
    : int(len(datamodule.train_dataset.data) * PCT_SAMPLES)
]
datamodule.val_dataset.data = datamodule.val_dataset.data[
    : int(len(datamodule.val_dataset.data) * PCT_SAMPLES)
]
datamodule.test_dataset.data = datamodule.test_dataset.data[
    : int(len(datamodule.test_dataset.data) * PCT_SAMPLES)
]
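The slicing above keeps the first int(len(data) * PCT_SAMPLES) samples of each split. With 1,000 samples and PCT_SAMPLES = 0.1, that is 100 samples. Because the slice takes the head of the list, the subset inherits whatever order the dataset stores its samples in. If that order is not random, a seeded random sample may be more representative.

The sketch below compares the two approaches on a plain list. Both head_fraction and random_fraction are hypothetical helpers, and the random variant is an alternative rather than what this tutorial does.

```python
import random


def head_fraction(data, pct):
    """Keep the first int(len(data) * pct) items, as the cell above does."""
    return data[: int(len(data) * pct)]


def random_fraction(data, pct, seed=0):
    """Alternative: keep a reproducible random sample of the same size."""
    k = int(len(data) * pct)
    rng = random.Random(seed)
    indices = sorted(rng.sample(range(len(data)), k))  # preserve original order
    return [data[i] for i in indices]


items = list(range(1000))
print(len(head_fraction(items, 0.1)))  # 100
print(len(random_fraction(items, 0.1)))  # 100
```

Note that int() truncates, so very small splits can shrink to zero samples (for example, int(5 * 0.1) is 0).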

Next, we load a ResNet-50 pre-trained on Sentinel-1 imagery (the SENTINEL1_ALL_MOCO weights) along with its matching transforms, and define a function that runs inference on a batch.

[ ]:
model_transform = ResNet50_Weights.SENTINEL1_ALL_MOCO.transforms
rn_model = resnet50(ResNet50_Weights.SENTINEL1_ALL_MOCO).to(DEVICE).eval()
[ ]:
def infer(batch):
    img = batch['image'].to(DEVICE)
    labels = batch['label'].numpy()
    # Each image has 4 channels: two for the pre-event image and two for the post-event image.
    # Split it into two two-channel images and embed each one separately.
    pre = model_transform({'image': img[:, :2]})['image']
    post = model_transform({'image': img[:, 2:]})['image']
    with torch.no_grad():
        # Concatenate the pre- and post-event embeddings along the feature axis
        embs = torch.cat([rn_model(pre), rn_model(post)], dim=1).cpu().numpy()
    return embs, labels
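The infer function encodes the pre-event and post-event images separately with the same backbone, then concatenates the two embedding vectors. Each sample is therefore described by twice as many features as a single image embedding.

The NumPy sketch below reproduces only that bookkeeping. The toy_encoder (global average pooling followed by a fixed random projection) is a hypothetical stand-in for the ResNet-50, and the feature size D = 8 is arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
D = 8  # arbitrary embedding size for the stand-in encoder
projection = rng.standard_normal((2, D))  # maps 2 pooled channels to D features


def toy_encoder(images):
    """Hypothetical stand-in for the backbone: (B, 2, H, W) -> (B, D)."""
    pooled = images.mean(axis=(2, 3))  # global average pooling -> (B, 2)
    return pooled @ projection


batch = rng.random((4, 4, 32, 32))  # (B, 4, H, W): pre- and post-event channels
pre, post = batch[:, :2], batch[:, 2:]
# Encode each time step with the same weights, then join along the feature axis
embs = np.concatenate([toy_encoder(pre), toy_encoder(post)], axis=1)
print(embs.shape)  # (4, 16)
```

Sharing one encoder for both acquisitions keeps the two halves of the vector in the same feature space. This presumably makes it easier for the downstream classifier to pick up differences between them.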

Next, we run inference on the training and test sets to compute the embeddings. The ResNet-50 is not trained further here; it serves purely as a feature extractor.

[ ]:
embeddings = defaultdict(list)
labels = defaultdict(list)
# Iterate over the training and test dataloaders, collecting one embedding per sample
for split, dataloader in [
    ('train', datamodule.train_dataloader()),
    ('test', datamodule.test_dataloader()),
]:
    for batch in tqdm(dataloader, desc=split.capitalize()):
        embs, labs = infer(batch)
        embeddings[split].extend(embs)
        labels[split].extend(labs)
# Merge the per-sample embeddings and labels into a single array per split
embeddings = {k: np.stack(v) for k, v in embeddings.items()}
labels = {k: np.array(v) for k, v in labels.items()}

Now, we can fit a classical model (e.g., Random Forest) using the embeddings as features and the labels as targets.

[ ]:
# Train a RandomForest classifier
clf = RandomForestClassifier(n_estimators=100, n_jobs=-1)
clf.fit(embeddings['train'], labels['train'])
# Evaluate the classifier on the test set
preds = clf.predict(embeddings['test'])
print(classification_report(labels['test'], preds))
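classification_report prints per-class precision, recall, and F1 score. For a given class:
- precision is TP / (TP + FP);
- recall is TP / (TP + FN);
- F1 is the harmonic mean of precision and recall.

The pure-Python sketch below computes these scores for binary labels, which makes the report's numbers easier to interpret. It is a simplified illustration, not scikit-learn's implementation. For instance, it simply returns 0.0 whenever a denominator is zero. The binary_scores helper and the toy labels are hypothetical.

```python
def binary_scores(y_true, y_pred, positive=1):
    """Precision, recall, and F1 for one class of a binary problem."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(t == positive and p == positive for t, p in pairs)
    fp = sum(t != positive and p == positive for t, p in pairs)
    fn = sum(t == positive and p != positive for t, p in pairs)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


# Toy labels: 1 = earthquake, 0 = no earthquake
y_true = [1, 1, 1, 0, 0, 0]
y_pred = [1, 1, 0, 1, 0, 0]
print(binary_scores(y_true, y_pred))  # all three scores are about 0.667 here
print(binary_scores(y_true, y_pred, positive=0))  # scores for the negative class
```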

Training a deep model from scratch

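The second approach trains a deep neural network from scratch. To simplify training, we use TorchGeo's ClassificationTask together with Lightning's Trainer.

Remember to set the in_channels parameter to 4, since each sample concatenates two two-channel images (pre-event and post-event).

Also note that trainer.fit and trainer.test call the data module's setup method again. This rebuilds the datasets at full size and discards the subsets created earlier. The limit_train_batches, limit_val_batches, and limit_test_batches arguments below therefore restrict each epoch to roughly PCT_SAMPLES of the batches instead.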

[ ]:
task = ClassificationTask(
    model='resnet18', in_channels=4, task='binary', loss='bce', lr=0.0001
)
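A binary classifier of this kind typically outputs a single logit per sample, and loss='bce' selects a binary cross-entropy loss. Presumably this loss is computed on logits, as PyTorch's BCEWithLogitsLoss does; that is an assumption about TorchGeo's internals.

For a logit x and a target y in {0, 1}, the loss is

  -[y log σ(x) + (1 - y) log(1 - σ(x))]

which can be rewritten in the numerically stable form

  max(x, 0) - x·y + log(1 + e^(-|x|))

The pure-Python sketch below evaluates both forms. The function names are hypothetical.

```python
import math


def bce_naive(x, y):
    """Binary cross-entropy through the sigmoid (fails for large |x|)."""
    p = 1.0 / (1.0 + math.exp(-x))
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))


def bce_with_logits(x, y):
    """Numerically stable form: max(x, 0) - x*y + log(1 + exp(-|x|))."""
    return max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))


for logit, target in [(2.0, 1), (-1.5, 0), (0.0, 1)]:
    print(logit, target, bce_naive(logit, target), bce_with_logits(logit, target))

# For a large logit the sigmoid rounds to exactly 1.0, so bce_naive would call
# log(0) and raise an error, while the stable form still works
print(bce_with_logits(100.0, 0))  # 100.0
```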
[ ]:
# Set up the trainer logger and checkpoint callback
default_root_dir = Path(tempfile.gettempdir()) / 'experiments'
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss', dirpath=default_root_dir, save_top_k=1, save_last=True
)
logger = CSVLogger(save_dir=default_root_dir, name='tutorial_logs')
# Set up the trainer
trainer = Trainer(
    accelerator='auto',
    callbacks=[checkpoint_callback],
    log_every_n_steps=1,
    logger=logger,
    max_epochs=max_epochs,
    limit_train_batches=PCT_SAMPLES,
    limit_val_batches=PCT_SAMPLES,
    limit_test_batches=PCT_SAMPLES,
    fast_dev_run=fast_dev_run,
)
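Lightning interprets a float passed to limit_train_batches (and to its validation and test counterparts) as a fraction of each epoch's batches. The sketch below approximates the resulting number of batches. It assumes that the dataloader keeps the last incomplete batch and that Lightning truncates the product with int().

The sketch also shows which epoch the ModelCheckpoint above would retain. With save_top_k=1 and monitor='val_loss', only the checkpoint with the lowest validation loss is kept (Lightning's default mode='min'); save_last=True additionally keeps the final checkpoint.

Both helpers are hypothetical, and the dataset size and loss values in the usage lines are made up for illustration.

```python
import math


def batches_per_epoch(num_samples, batch_size, limit):
    """Approximate number of batches Lightning runs per epoch for a float limit."""
    full = math.ceil(num_samples / batch_size)  # the last partial batch is kept
    return int(full * limit)  # assumed truncation of the fractional count


def kept_checkpoint_epoch(val_losses):
    """Epoch whose checkpoint survives with save_top_k=1 and mode='min'."""
    return min(range(len(val_losses)), key=lambda epoch: val_losses[epoch])


# Made-up numbers, for illustration only
print(batches_per_epoch(num_samples=1000, batch_size=4, limit=0.1))  # 25
print(kept_checkpoint_epoch([0.70, 0.62, 0.65, 0.60, 0.61]))  # 3
```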
[ ]:
# Fit the model
trainer.fit(model=task, datamodule=datamodule)
[ ]:
# Test the model and print the results
trainer.test(model=task, datamodule=datamodule)
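The CSVLogger configured above writes the logged values to a metrics.csv file inside its versioned log directory, which is available as logger.log_dir. Rows logged at different steps usually fill only some of the columns, so empty cells are common.

The sketch below relies only on Python's csv module and collects the non-empty values of one column. The helper name read_metric and the column name val_loss are assumptions; check them against your own file's header.

```python
import csv


def read_metric(lines, column):
    """Return the non-empty values of `column` from CSV lines, as floats."""
    reader = csv.DictReader(lines)
    return [float(row[column]) for row in reader if row.get(column)]


# Toy rows standing in for a metrics.csv file, for illustration only
example = ['epoch,step,train_loss,val_loss', '0,9,0.69,', '0,9,,0.66']
print(read_metric(example, 'val_loss'))  # [0.66]

# Possible usage against the real file (path layout assumed):
# from pathlib import Path
# with open(Path(logger.log_dir) / 'metrics.csv', newline='') as f:
#     print(read_metric(f, 'val_loss'))
```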

This tutorial showed two ways to tackle the task:
- using a TorchGeo pre-trained model as a frozen feature extractor for a classical classifier, as an alternative to full training;
- training a deep model from scratch.

You can find the full list of available pre-trained models in the TorchGeo documentation.

For a review of deep learning applications in earthquake engineering, see https://arxiv.org/abs/2405.09021
