[ ]:
# Copyright (c) TorchGeo Contributors. All rights reserved.
# Licensed under the MIT License.
Earthquake Prediction¶
Written by Daniele Rege Cambrin
Introduction¶
The objective of this tutorial is to walk through the QuakeSet dataset and cover the following topics:
How to use TorchGeo data modules to load datasets and plot samples;
How to use embeddings from a TorchGeo pre-trained model to train a classical model (e.g., a Random Forest);
How to train a new deep model using TorchGeo tasks and a Lightning trainer.
Environment¶
For the environment, we will install the torchgeo, h5py, and scikit-learn packages.
[ ]:
%pip install torchgeo h5py scikit-learn
Imports¶
[ ]:
import tempfile
from collections import defaultdict
from pathlib import Path
import numpy as np
import torch
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report
from tqdm import tqdm
from torchgeo.datamodules import QuakeSetDataModule
from torchgeo.models import ResNet50_Weights, resnet50
from torchgeo.trainers import ClassificationTask
Dataset¶
We will use the QuakeSet dataset (licensed under the OpenRAIL license), which contains patches from around the world captured before and after an earthquake, along with corresponding negative examples.
The dataset uses SAR imagery from the Sentinel-1 satellite at 10 m spatial resolution. The task is to predict, for each pair of images, whether an earthquake occurred between the two acquisitions.
[ ]:
num_workers = 1
batch_size = 4
max_epochs = 10
fast_dev_run = False
[ ]:
tmp_path = Path(tempfile.gettempdir())
# The data module has already been implemented in TorchGeo, so we can use it
datamodule = QuakeSetDataModule(
    batch_size=batch_size, num_workers=num_workers, download=True, root=tmp_path
)
[ ]:
# This will download the dataset
datamodule.prepare_data()
Batch Visualization¶
The QuakeSetDataModule already implements a plot function to show the samples.
Remember to call setup with the 'fit' or 'test' stage before using it (otherwise, you will get an error).
[ ]:
datamodule.setup('fit')
datamodule.setup('test')
sample = datamodule.val_dataset[0]
datamodule.plot(sample)
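If you want to look at the raw tensors behind the plot, a short optional sketch like the following prints each key and its shape (the exact spatial size depends on the dataset version):
[ ]:
# Inspect the raw sample: a four-channel SAR image (pre- and post-event) and a binary label
for key, value in sample.items():
    if torch.is_tensor(value):
        print(key, tuple(value.shape), value.dtype)
    else:
        print(key, value)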
Train an ML model on pre-trained embeddings¶
The first approach uses a pre-trained deep learning model to compute embeddings, which are then used to train a classical machine learning model.
First, we set the constants that select which device to use and what fraction of the dataset to use (10% by default).
[ ]:
# Change to "cpu" if you don't have a GPU
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# Percentage of samples to use
PCT_SAMPLES = 0.1
[ ]:
# We take a subset of the dataset to speed up training
datamodule.train_dataset.data = datamodule.train_dataset.data[
    : int(len(datamodule.train_dataset.data) * PCT_SAMPLES)
]
datamodule.val_dataset.data = datamodule.val_dataset.data[
    : int(len(datamodule.val_dataset.data) * PCT_SAMPLES)
]
datamodule.test_dataset.data = datamodule.test_dataset.data[
    : int(len(datamodule.test_dataset.data) * PCT_SAMPLES)
]
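As a quick optional check that the subsetting worked as intended, you can print the reduced split sizes:
[ ]:
# Confirm the reduced number of samples in each split
print(
    len(datamodule.train_dataset.data),
    len(datamodule.val_dataset.data),
    len(datamodule.test_dataset.data),
)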
Now, we load a ResNet50 pre-trained on Sentinel-1 images and define a function that runs inference to extract embeddings.
[ ]:
model_transform = ResNet50_Weights.SENTINEL1_ALL_MOCO.transforms
rn_model = resnet50(ResNet50_Weights.SENTINEL1_ALL_MOCO).to(DEVICE).eval()
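To see the size of the feature vector the backbone returns for a single two-channel acquisition, you can probe it with a dummy input; the 1x2x512x512 shape below is only an illustrative assumption, not a requirement of the model:
[ ]:
# Probe the backbone output size with a random two-channel input (arbitrary patch size)
with torch.no_grad():
    dummy = torch.randn(1, 2, 512, 512, device=DEVICE)
    print(rn_model(dummy).shape)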
[ ]:
def infer(batch):
    img = batch['image'].to(DEVICE)
    labels = batch['label']
    # Each image has 4 channels (two for the pre-event image and two for the post-event image).
    # We need to split it into two images with two channels each.
    pre = model_transform({'image': img[:, :2]})['image']
    post = model_transform({'image': img[:, 2:]})['image']
    with torch.no_grad():
        embs = torch.concat([rn_model(pre), rn_model(post)], axis=1).cpu().numpy()
    return embs, labels
Now, we run inference on the training and test sets to compute the embeddings. The model will be used as a feature extractor.
[ ]:
embeddings = defaultdict(list)
labels = defaultdict(list)
# We iterate over the train_dataloader
for i, batch in tqdm(enumerate(datamodule.train_dataloader()), desc='Train'):
    for j, (emb, lab) in enumerate(zip(*infer(batch))):
        embeddings['train'].append(emb)
        labels['train'].append(lab)
# We iterate over the test_dataloader
for i, batch in tqdm(enumerate(datamodule.test_dataloader()), desc='Test'):
    for j, (emb, lab) in enumerate(zip(*infer(batch))):
        embeddings['test'].append(emb)
        labels['test'].append(lab)
# Now we merge the embeddings and labels into a single array
embeddings = {k: np.stack(v) for k, v in embeddings.items()}
labels = {k: np.array(v) for k, v in labels.items()}
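Before fitting a classifier, it is worth sanity-checking the resulting arrays; each row is the concatenation of the pre- and post-event feature vectors for one sample:
[ ]:
# Sanity check: one feature row and one label per sample
print(embeddings['train'].shape, labels['train'].shape)
print(embeddings['test'].shape, labels['test'].shape)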
Now, we can fit a classical model (e.g., Random Forest) using the embeddings as features and the labels as targets.
[ ]:
# Train a RandomForest classifier
clf = RandomForestClassifier(n_estimators=100, n_jobs=-1)
clf.fit(embeddings['train'], labels['train'])
# Evaluate the classifier on the test set
preds = clf.predict(embeddings['test'])
print(classification_report(labels['test'], preds))
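If you also want a threshold-independent metric, a minimal sketch like the following computes the ROC AUC from the classifier's predicted probabilities (using scikit-learn's roc_auc_score, which is an addition to the evaluation above):
[ ]:
from sklearn.metrics import roc_auc_score

# Use the predicted probability of the positive class for a threshold-independent metric
probs = clf.predict_proba(embeddings['test'])[:, 1]
print('ROC AUC:', roc_auc_score(labels['test'], probs))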
Training a deep model from scratch¶
The second approach trains a deep neural network from scratch. To this end, we can use TorchGeo's ClassificationTask and Lightning's Trainer to simplify the training.
Remember to set the in_channels parameter to 4, since we are concatenating two two-channel images (pre- and post-event).
[ ]:
task = ClassificationTask(
    model='resnet18', in_channels=4, task='binary', loss='bce', lr=0.0001
)
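The task builds the backbone and stores the hyperparameters internally; a quick optional way to inspect what was configured:
[ ]:
# Inspect the configured backbone and the stored hyperparameters (optional)
print(type(task.model).__name__)
print(task.hparams)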
[ ]:
# Set up the trainer logger and checkpoint callback
default_root_dir = Path(tempfile.gettempdir()) / 'experiments'
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss', dirpath=default_root_dir, save_top_k=1, save_last=True
)
logger = CSVLogger(save_dir=default_root_dir, name='tutorial_logs')
# Set up the trainer
trainer = Trainer(
    accelerator='auto',
    callbacks=[checkpoint_callback],
    log_every_n_steps=1,
    logger=logger,
    max_epochs=max_epochs,
    limit_train_batches=PCT_SAMPLES,
    limit_val_batches=PCT_SAMPLES,
    limit_test_batches=PCT_SAMPLES,
    fast_dev_run=fast_dev_run,
)
[ ]:
# Fit the model
trainer.fit(model=task, datamodule=datamodule)
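If you prefer to evaluate the best checkpoint (by validation loss) rather than the final state, one option is to reload it before testing; a minimal sketch using the checkpoint callback defined above (best_model_path is empty when fast_dev_run is enabled):
[ ]:
# Optionally reload the best checkpoint tracked by the ModelCheckpoint callback
best_ckpt = checkpoint_callback.best_model_path
if best_ckpt:
    task = ClassificationTask.load_from_checkpoint(best_ckpt)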
[ ]:
# Test the model and print the results
trainer.test(model=task, datamodule=datamodule)
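The CSVLogger writes the recorded metrics to a metrics.csv file inside its log directory; a short sketch to peek at it, assuming pandas is available in your environment:
[ ]:
import pandas as pd

# Peek at the metrics recorded by CSVLogger (the exact columns depend on the task)
metrics_file = Path(logger.log_dir) / 'metrics.csv'
if metrics_file.exists():
    print(pd.read_csv(metrics_file).tail())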
This tutorial shows how to use TorchGeo pre-trained models as an alternative to full training from scratch. You can see the full list of available models in the TorchGeo documentation.
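For example, the pre-trained weights available for ResNet-50 can be listed directly from the enum imported earlier:
[ ]:
# List the pre-trained weights available for ResNet-50 in this TorchGeo version
for weights in ResNet50_Weights:
    print(weights)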
For a review of applications of deep learning in earthquake engineering, see https://arxiv.org/abs/2405.09021.