{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Copyright (c) TorchGeo Contributors. All rights reserved.\n", "# Licensed under the MIT License." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Earthquake Prediction" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "*Written by Daniele Rege Cambrin*" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Introduction" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The objective of this tutorial is to go through the QuakeSet dataset and cover the following topics:\n", "\n", "* How to use TorchGeo data modules to load datasets and plot samples;\n", "* How to use TorchGeo pre-trained model embeddings to train a classical model (e.g., Random Forest);\n", "* How to train a new TorchGeo deep model using tasks and a trainer." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Environment" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For the environment, we will install the torchgeo, h5py, and scikit-learn packages." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%pip install torchgeo h5py scikit-learn" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import tempfile\n", "from collections import defaultdict\n", "from pathlib import Path\n", "\n", "import numpy as np\n", "import torch\n", "from lightning.pytorch import Trainer\n", "from lightning.pytorch.callbacks import ModelCheckpoint\n", "from lightning.pytorch.loggers import CSVLogger\n", "from sklearn.ensemble import RandomForestClassifier\n", "from sklearn.metrics import classification_report\n", "from tqdm import tqdm\n", "\n", "from torchgeo.datamodules import QuakeSetDataModule\n", "from torchgeo.models import ResNet50_Weights, resnet50\n", "from torchgeo.trainers import ClassificationTask" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We will use the [QuakeSet dataset](https://doi.org/10.59297/n89yc374) (licensed under the OpenRAIL License), which contains patches from around the world acquired before and after an earthquake, with corresponding negative examples.\n", "\n", "The dataset uses SAR imagery from the Sentinel-1 satellite with a spatial resolution of 10 m. The task is to predict, for each pair of images, whether an earthquake occurred between them." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "nbmake": { "mock": { "batch_size": 1, "fast_dev_run": true, "max_epochs": 1, "num_workers": 0 } } }, "outputs": [], "source": [ "num_workers = 1\n", "batch_size = 4\n", "max_epochs = 10\n", "fast_dev_run = False" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "nbmake": { "mock": { "datamodule.dataset_class.filename": "earthquakes_sample.h5", "datamodule.dataset_class.url": "https://hf.co/datasets/DarthReca/quakeset/resolve/4521cf03c4bb67084c0460412785722646f2ec9d/earthquakes_sample.h5" } } }, "outputs": [], "source": [ "tmp_path = Path(tempfile.gettempdir())\n", "# The data module has already been implemented in TorchGeo, so we can use it\n", "datamodule = QuakeSetDataModule(\n", "    batch_size=batch_size, num_workers=num_workers, download=True, root=tmp_path\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# This will download the dataset\n", "datamodule.prepare_data()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Batch Visualization" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The QuakeSetDataModule already has a `plot` function implemented to show the samples.\n", "\n", "Remember to call `setup` with the *fit* or *test* stage before accessing the corresponding datasets (otherwise, you will get an error)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "datamodule.setup('fit')\n", "datamodule.setup('test')\n", "sample = datamodule.val_dataset[0]\n", "datamodule.plot(sample)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train ML Model on Pretrained Embeddings" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The first approach uses a pre-trained deep-learning model to compute embeddings, which are then used to train a classical machine learning model.\n", "\n", "First, we set the constants that select the device and the fraction of the dataset to use (10% by default)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Use the GPU if one is available, otherwise fall back to the CPU\n", "DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "# Fraction of samples to use (0.1 = 10%)\n", "PCT_SAMPLES = 0.1" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# We take a subset of the dataset to speed up training\n", "datamodule.train_dataset.data = datamodule.train_dataset.data[\n", "    : int(len(datamodule.train_dataset.data) * PCT_SAMPLES)\n", "]\n", "datamodule.val_dataset.data = datamodule.val_dataset.data[\n", "    : int(len(datamodule.val_dataset.data) * PCT_SAMPLES)\n", "]\n", "datamodule.test_dataset.data = datamodule.test_dataset.data[\n", "    : int(len(datamodule.test_dataset.data) * PCT_SAMPLES)\n", "]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, we load a ResNet50 pre-trained on Sentinel-1 images and define the function to make inferences." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "model_transform = ResNet50_Weights.SENTINEL1_ALL_MOCO.transforms\n", "rn_model = resnet50(ResNet50_Weights.SENTINEL1_ALL_MOCO).to(DEVICE).eval()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def infer(batch):\n", "    img = batch['image'].to(DEVICE)\n", "    labels = batch['label']\n", "    # Each image has 4 channels (two channels for pre-event image, and two for post-event).\n", "    # We need to split it into two images with two channels each.\n", "    pre = model_transform({'image': img[:, :2]})['image']\n", "    post = model_transform({'image': img[:, 2:]})['image']\n", "    # Gradients are not needed for feature extraction\n", "    with torch.no_grad():\n", "        # Concatenate the pre- and post-event outputs along the feature dimension\n", "        embs = torch.cat([rn_model(pre), rn_model(post)], dim=1).cpu().numpy()\n", "    return embs, labels" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, we run inference on the training and test sets to compute the embeddings. The model will be used as a feature extractor." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "embeddings = defaultdict(list)\n", "labels = defaultdict(list)\n", "# We iterate over the train_dataloader\n", "for batch in tqdm(datamodule.train_dataloader(), desc='Train'):\n", "    for emb, lab in zip(*infer(batch)):\n", "        embeddings['train'].append(emb)\n", "        labels['train'].append(lab)\n", "# We iterate over the test_dataloader\n", "for batch in tqdm(datamodule.test_dataloader(), desc='Test'):\n", "    for emb, lab in zip(*infer(batch)):\n", "        embeddings['test'].append(emb)\n", "        labels['test'].append(lab)\n", "# Now we merge the embeddings and labels into a single array\n", "embeddings = {k: np.stack(v) for k, v in embeddings.items()}\n", "labels = {k: np.array(v) for k, v in labels.items()}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now, we can fit a classical model (e.g., Random Forest) using the embeddings as features and the labels as targets." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Train a RandomForest classifier\n", "clf = RandomForestClassifier(n_estimators=100, n_jobs=-1)\n", "clf.fit(embeddings['train'], labels['train'])\n", "# Evaluate the classifier on the test set\n", "preds = clf.predict(embeddings['test'])\n", "print(classification_report(labels['test'], preds))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Training a Deep Model from Scratch" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The second approach trains a deep neural network from scratch. To this end, we can use TorchGeo's `ClassificationTask` and Lightning's `Trainer` to simplify the training.\n", "\n", "Remember to set the `in_channels` parameter to 4 since we are concatenating two (pre- and post-event) two-channel images." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "task = ClassificationTask(\n", " model='resnet18', in_channels=4, task='binary', loss='bce', lr=0.0001\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Set up the trainer logger and checkpoint callback\n", "default_root_dir = Path(tempfile.gettempdir()) / 'experiments'\n", "checkpoint_callback = ModelCheckpoint(\n", " monitor='val_loss', dirpath=default_root_dir, save_top_k=1, save_last=True\n", ")\n", "logger = CSVLogger(save_dir=default_root_dir, name='tutorial_logs')\n", "# Set up the trainer\n", "trainer = Trainer(\n", " accelerator='auto',\n", " callbacks=[checkpoint_callback],\n", " log_every_n_steps=1,\n", " logger=logger,\n", " max_epochs=max_epochs,\n", " limit_train_batches=PCT_SAMPLES,\n", " limit_val_batches=PCT_SAMPLES,\n", " limit_test_batches=PCT_SAMPLES,\n", " fast_dev_run=fast_dev_run,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Fit the model\n", "trainer.fit(model=task, datamodule=datamodule)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Test the model and print the results\n", "trainer.test(model=task, datamodule=datamodule)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This tutorial shows how to use the TorchGeo pretrained models in place of full training from scratch. 
You can see a full list of available models in the [TorchGeo documentation](https://torchgeo.readthedocs.io/en/latest/api/models.html#pretrained-weights).\n", "\n", "For a review of applications of deep learning in earthquake engineering, see <https://arxiv.org/abs/2405.09021>." ] } ], "metadata": { "accelerator": "GPU", "kernelspec": { "display_name": "torchgeo_dev", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3" } }, "nbformat": 4, "nbformat_minor": 0 }