为什么用 Transformer 解码 EEG

EEG 解码的传统方法是 LDA + CSP:手工提取特征,再喂给线性分类器。这条路成熟但天花板明显——特征工程依赖领域知识,换一个范式就得重来。

深度学习跳过了特征工程:CNN 直接从原始信号学特征。但 CNN 有自己的局限——它的卷积核是局部的,对 EEG 这种时序信号的长程依赖建模能力弱。

Transformer 的自注意力机制天然擅长捕捉长程依赖。NLP 领域已经证明了这一点。问题是:把 NLP 的 Transformer 搬到 EEG 上,哪些设计要改?

答案是几乎全要改。

EEG 不是 NLP:三个关键差异

NLP EEG
输入 离散 token(词表索引) 连续多通道时序(float32 数组)
位置信息 词序(离散位置) 采样时间点(连续,长度可变)
预测目标 下一个词(自回归生成) 整段信号属于哪类(分类)

这三个差异决定了架构的每一个选择。

架构设计

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
原始 EEG (n_channels, n_times)


Conv1d Token Embedding ──→ (n_tokens, n_channels)


Linear Projection ──→ (n_tokens, d_model)


N × Decoder Block
┌─────────────────────┐
│ LayerNorm │
│ Causal MHA + RoPE │
│ Residual │
│ LayerNorm │
│ FFN (4× expansion) │
│ Residual │
└─────────────────────┘


取最后位置 → Linear → n_classes

1. Token Embedding:不用词表,用卷积

NLP 有词表,每个词查表得到 embedding。EEG 没有词表——输入是 (n_channels, n_times) 的连续信号。

做法:用 Conv1d(n_channels, n_channels, kernel, stride) 把时域信号切成 token 序列。每个 token 覆盖 kernel 个时间点,步长 stride,输出的序列长度就是 n_tokens = (n_times - kernel) // stride + 1

1
2
3
4
5
6
7
8
9
class _TokenEmbedding(nn.Module):
def __init__(self, n_channels, kernel, stride):
super().__init__()
self.conv = nn.Conv1d(n_channels, n_channels,
kernel_size=kernel, stride=stride)

def forward(self, x):
# x: (B, n_ch, n_times) → (B, n_tokens, n_ch)
return self.conv(x).transpose(1, 2)

为什么用 Conv1d 而不是固定分帧:信号处理里"分帧"是按固定窗口切信号、每帧取均值或原始采样点——这是不可学习的。Conv1d 做的事本质相同(kernel=窗口,stride=步长),但权重可学习:kernel 内每个时间点的贡献由训练决定,比分帧平均更有表达力。

kernel/stride 怎么选:如果用户不指定,自动算。目标是让 n_tokens ≈ 128,50% 重叠(kernel = 2 * stride):

1
2
stride = max(1, n_times // target_tokens)
kernel = stride * 2

128 个 token 足够 Transformer 捕捉时序依赖,又不会让 O(N²) 的注意力计算爆炸。

2. RoPE:让位置编码外推

NLP 用固定的正弦位置编码或可学习的位置编码。问题:可学习位置编码只能编码训练时见过的长度,推理时遇到更长序列就废了。

RoPE(Rotary Positional Embedding)是更好的选择:它把位置信息编码成 Q/K 的旋转角度,天然支持长度外推——训练时见 128 个 token,推理时给 200 个也能跑。

1
2
3
4
5
6
7
8
9
10
11
class _RotaryPositionalEmbedding(nn.Module):
def __init__(self, d_model, max_seq_len=4096, base=10000.0):
super().__init__()
inv_freq = 1.0 / (base ** (torch.arange(0, d_model, 2).float() / d_model))
self.register_buffer("inv_freq", inv_freq, persistent=False)

def forward(self, x, positions):
freqs = torch.outer(positions.float(), self.inv_freq)
cos = freqs.cos().repeat_interleave(2, dim=-1)
sin = freqs.sin().repeat_interleave(2, dim=-1)
return x * cos + self._rotate_half(x) * sin

核心思想:位置 m 的 Q/K 不加偏置,而是旋转。_rotate_half 把特征维的相邻两个值当作一对,做 2D 旋转。内积 Q·K 自然包含相对位置信息 m-n,不需要显式的位置编码表。

3. 因果注意力:只看过去,不看未来

这是 GPT 风格 Transformer 的标志性设计。标准 BERT 式自注意力每个位置能看所有位置,但 EEG 解码如果用双向注意力,模型会"偷看"未来的脑电信号——这在实时 BCI 场景下是不可能的。

因果掩码:上三角为 -inf,softmax 后变成 0:

1
2
3
4
5
class _CausalMask(nn.Module):
def __init__(self, size):
super().__init__()
mask = torch.triu(torch.full((size, size), float("-inf")), diagonal=1)
self.register_buffer("mask", mask, persistent=False)

在注意力分数上加这个掩码,位置 t 只能 attend 到 0…t。

4. Pre-LN + 残差:稳定训练

1
2
3
4
5
class _DecoderBlock(nn.Module):
def forward(self, x):
x = x + self.attn(self.ln1(x)) # Pre-LN: 先归一化再进注意力
x = x + self.ffn(self.ln2(x)) # 同理
return x

Pre-LN(先 LayerNorm 再进子层)比 Post-LN(先子层再 LayerNorm)训练更稳定。这是 GPT-2 以后的标准做法。FFN 扩张比 4×,用 GELU 激活。

5. 分类头:取最后一个位置

BERT 取 [CLS] token 的输出做分类。GPT 没有 CLS,那取哪个位置?

取最后一个:x[:, -1, :]。因为因果注意力下,最后一个位置积累了前面所有位置的信息——它"看"了整段信号。

6. 长度自适应推理

推理时 EEG 长度可能和训练时不一样。因为用了 Conv1d token embedding + RoPE,模型天然支持任意长度输入(只要 n_times >= kernel)。如果推理时 token 数超过训练时,RoPE 的位置会外推,给一个警告:

1
2
if n_tokens > self._train_n_tokens:
warnings.warn("RoPE position extrapolation; accuracy may degrade")

7. 逐通道 z-score 归一化

EEG 通道间的幅度差异很大(额区 vs 枕区),不归一化的话梯度会被大信号通道主导。训练时计算每个通道的均值/标准差,推理时用同样的统计量:

1
2
3
4
5
if self.normalize:
self._mean = X.mean(axis=(0, 2), keepdims=True)
self._std = X.std(axis=(0, 2), keepdims=True)
self._std = np.where(self._std < 1e-8, 1.0, self._std) # 零方差通道保护
X = (X - self._mean) / self._std

均值和标准差随模型一起保存,推理时自动应用。

和 CNN 解码器的对比

同项目里还有一个 CNN 解码器作为对照:

1
2
3
4
5
6
# CNN: 输入 (B, 1, n_ch, n_times),2D 卷积,全连接分类
class _EEGCNN(nn.Module):
def __init__(self, n_channels, n_times, n_classes, dropout=0.25):
self.conv1 = nn.Conv2d(1, 16, kernel_size=(n_channels, 3))
self.conv2 = nn.Conv2d(16, 32, kernel_size=(1, 3))
self.fc = nn.Linear(flatten_size, n_classes)
CNN Transformer
输入视角 把 EEG 当"图像"(1×n_ch×n_times) 把 EEG 当"时序"(n_tokens×d_model)
感受野 局部(卷积核 3) 全局(自注意力看所有 token)
长度灵活性 固定(全连接层绑死 n_times) 自适应(RoPE 外推)
训练开销 小(参数少,全批量) 大(O(N²) 注意力,需 mini-batch)

CNN 快但天花板低,Transformer 慢但上限高。实际选择看数据量和计算资源。

训练细节

  • 优化器:AdamW(lr=5e-4, weight_decay=1e-4)
  • 损失:CrossEntropyLoss
  • 不用 DataLoader——数据直接放 GPU,手动 shuffle index 切 mini-batch,省掉 dataloader 的开销
  • 默认 50 epoch,batch_size=32

踩过的坑

  1. EEG 数据量太小:BCI 数据集通常几十到几百个 trial,Transformer 容易过拟合。数据增强(高斯噪声、时间平移)很重要
  2. RoPE 外推不是免费的:推理时 token 数超过训练时的 1.5 倍以上,准确率会明显下降。不是"支持任意长度"就真的随便用
  3. d_model 不要太大:EEG 通道数通常 32-64,d_model=64 已经够用。512 是 NLP 的配置,在 EEG 上只会过拟合
  4. 因果注意力不一定比双向好:离线分析场景下,双向注意力允许"偷看未来",准确率更高。因果注意力的价值在实时 BCI——只有实时场景才真正需要"不能看未来"的约束

总结

把 Transformer 从 NLP 搬到 EEG,不是换个输入格式就完了。Token 怎么构建、位置怎么编码、注意力要不要因果、分类头取哪个位置——每个设计选择都需要根据 EEG 的特性重新思考。

这个实现不是 SOTA(BCI 领域的 SOTA 是 EEGNet + 数据增强),但它完整地回答了一个问题:如果要从零实现一个 EEG Transformer,需要做哪些设计决策,以及为什么。