Foundation Models (Signal-JEPA)

Tutorial for loading and fine-tuning pre-trained foundation models with Braindecode

Foundation models are large models pretrained on broad, diverse data with self-supervised objectives to learn general-purpose representations. After pretraining they can be adapted to many downstream tasks (fine-tuning, feature extraction or linear probing), often reducing the need for labeled data and improving robustness across tasks.
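To make the difference between these adaptation modes concrete, here is a minimal linear-probing sketch in plain PyTorch. The `backbone` is a hypothetical stand-in for a pretrained encoder, not a Braindecode model, and the shapes, learning rate and random data are assumptions chosen only for illustration.

```python
import torch
from torch import nn

# Hypothetical stand-in for a pretrained encoder (NOT a Braindecode model).
backbone = nn.Sequential(
    nn.Conv1d(in_channels=16, out_channels=8, kernel_size=5),
    nn.AdaptiveAvgPool1d(1),
    nn.Flatten(),
)
head = nn.Linear(8, 4)  # new task-specific classifier: 8 features -> 4 classes

# Linear probing: freeze the backbone so that only the head is trained.
for param in backbone.parameters():
    param.requires_grad = False

model = nn.Sequential(backbone, head)
optimizer = torch.optim.Adam(
    [p for p in model.parameters() if p.requires_grad], lr=1e-3
)

# One training step on random data shaped (batch, channels, time).
x = torch.randn(2, 16, 256)
y = torch.tensor([0, 3])
loss = nn.functional.cross_entropy(model(x), y)
optimizer.zero_grad()
loss.backward()
optimizer.step()

# Only the head's weights and biases remain trainable.
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
```

Skipping the freezing loop would turn the same code into full fine-tuning, where the backbone weights are updated together with the head.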

Signal-JEPA follows this paradigm and introduces a novel approach for self-supervised learning of EEG signals. The model is pre-trained to predict latent representations of future signal segments based on past signal segments, using a joint-embedding architecture. This approach allows the model to learn meaningful representations of EEG data without requiring labeled data.
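The idea of predicting latent representations, rather than raw signal values, can be sketched with a toy example. This is not the Signal-JEPA implementation: the encoder, the linear predictor, the half-window split and the moving-average update of the target encoder are all assumptions made for illustration.

```python
import copy

import torch
from torch import nn

torch.manual_seed(0)


def make_encoder(n_chans: int, dim: int) -> nn.Module:
    # Hypothetical toy encoder: maps (batch, n_chans, time) to (batch, dim).
    return nn.Sequential(
        nn.Conv1d(n_chans, dim, kernel_size=7),
        nn.GELU(),
        nn.AdaptiveAvgPool1d(1),
        nn.Flatten(),
    )


n_chans, dim = 16, 32
context_encoder = make_encoder(n_chans, dim)
# The target encoder starts as a copy and receives no gradients.
target_encoder = copy.deepcopy(context_encoder)
for p in target_encoder.parameters():
    p.requires_grad = False
predictor = nn.Linear(dim, dim)

# Split an unlabeled window into a "past" and a "future" segment.
x = torch.randn(4, n_chans, 256)
past, future = x[..., :128], x[..., 128:]

with torch.no_grad():
    target = target_encoder(future)  # latent representation to predict
prediction = predictor(context_encoder(past))

# The loss is computed in latent space, not on the raw signal.
loss = nn.functional.mse_loss(prediction, target)
loss.backward()

# Assumed update rule: the target encoder slowly tracks the context encoder.
momentum = 0.99
with torch.no_grad():
    for p_t, p_c in zip(target_encoder.parameters(), context_encoder.parameters()):
        p_t.mul_(momentum).add_(p_c, alpha=1 - momentum)
```

No labels appear anywhere in this loop: the training signal comes entirely from the relationship between two segments of the same recording.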

Here is a minimal example that loads the pre-trained Signal-JEPA weights into a downstream architecture:

import torch
from braindecode.models import SignalJEPA_PreLocal

# 1. Download the pre-trained weights:
state_dict = torch.hub.load_state_dict_from_url(
    url="https://huggingface.co/braindecode/SignalJEPA/resolve/main/signal-jepa_16s-60_adeuwv4s.pth"
)

# 2. Instantiate the model:
model = SignalJEPA_PreLocal(
    sfreq=128.0,
    input_window_seconds=2.0,  # Adjust according to your data
    n_chans=16,                # Adjust according to your data
    n_outputs=4,               # Adjust according to your data
)

# 3. Load the weights into the model:
model.load_state_dict(state_dict, strict=False)
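With `strict=False`, keys that do not match are skipped silently, so it can be worth inspecting what was actually loaded. The sketch below uses two small hypothetical `nn.Sequential` models instead of the Signal-JEPA checkpoint, so that it runs without a download. The key names are specific to this toy example.

```python
import torch
from torch import nn

# Hypothetical "pretrained" model providing the weights to transfer.
pretrained = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
state_dict = pretrained.state_dict()  # keys: "0.weight", "0.bias"

# Hypothetical downstream model: same first layer plus a new classification layer.
downstream = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4))

# load_state_dict returns a named tuple listing the keys that did not match.
result = downstream.load_state_dict(state_dict, strict=False)
print("Missing keys (left at their initial values):", result.missing_keys)
print("Unexpected keys (ignored):", result.unexpected_keys)
```

Here the missing keys are those of the new final layer, which is expected. Note that `strict=False` only tolerates missing or unexpected keys: a key that matches by name but differs in shape still raises an error.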

More details can be found in the full article, presented at the 9th Graz Brain-Computer Interface Conference (reference below).

The tutorial linked above explains in detail the full process of loading a pre-trained model and fine-tuning it on a downstream task.


References