Foundation Models (Signal-JEPA)
Tutorial for loading and fine-tuning pre-trained foundation models with Braindecode
Foundation models are large models pre-trained on broad, diverse data with self-supervised objectives to learn general-purpose representations. After pre-training, they can be adapted to many downstream tasks (fine-tuning, feature extraction, or linear probing), often reducing the need for labeled data and improving robustness across tasks.
Signal-JEPA follows this paradigm and introduces a novel approach for self-supervised learning of EEG signals. The model is pre-trained to predict latent representations of future signal segments based on past signal segments, using a joint-embedding architecture. This approach allows the model to learn meaningful representations of EEG data without requiring labeled data.
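To make the objective concrete, here is a schematic sketch of a JEPA-style joint-embedding loss. The toy encoder, target encoder, and predictor below are illustrative placeholders, not the actual Signal-JEPA modules, and the past/future split is simplified:
import torch
from torch import nn
# Toy stand-ins for the context encoder, target encoder (e.g. an EMA copy),
# and predictor; shapes assume 16 channels and 128-sample segments.
encoder = nn.Sequential(nn.Flatten(), nn.Linear(16 * 128, 64))
target_encoder = nn.Sequential(nn.Flatten(), nn.Linear(16 * 128, 64))
predictor = nn.Linear(64, 64)
x = torch.randn(8, 16, 256)  # (batch, channels, time): 2 s at 128 Hz
context, target = x[..., :128], x[..., 128:]  # past / future segments
z_context = encoder(context)
with torch.no_grad():  # targets are latent embeddings, not raw signals
    z_target = target_encoder(target)
# Joint-embedding objective: predict the target's latent representation
# from the context's, and compare the two in latent space.
loss = nn.functional.mse_loss(predictor(z_context), z_target)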
Here is a minimal example that loads the pre-trained Signal-JEPA weights into a downstream architecture:
import torch
from braindecode.models import SignalJEPA_PreLocal
# 1. Download the pre-trained weights:
state_dict = torch.hub.load_state_dict_from_url(
    url="https://huggingface.co/braindecode/SignalJEPA/resolve/main/signal-jepa_16s-60_adeuwv4s.pth"
)
# 2. Instantiate the model:
model = SignalJEPA_PreLocal(
    sfreq=128.0,
    input_window_seconds=2.0,  # Adjust according to your data
    n_chans=16,  # Adjust according to your data
    n_outputs=4,  # Adjust according to your data
)
# 3. Load the weights into the model. strict=False is needed because the
# randomly initialized downstream head is not part of the checkpoint:
model.load_state_dict(state_dict, strict=False)
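As mentioned above, a common alternative to full fine-tuning is linear probing: freeze the pre-trained backbone and train only the new head. A minimal sketch, assuming the checkpoint keys match the backbone parameter names:
# 4. (Optional) Freeze the pre-trained backbone for linear probing.
for name, param in model.named_parameters():
    if name in state_dict:  # parameter came from the checkpoint
        param.requires_grad = False
# Only the randomly initialized head's parameters are trained.
optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad), lr=1e-3
)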
More details can be found in the full article, presented at the 9th Graz Brain-Computer Interface Conference (reference below).
The tutorial linked above explains in detail the full process of loading a pre-trained model and fine-tuning it on a downstream task.