GPT 风格 Transformer 解码 EEG:为什么、怎么做、踩了什么坑
为什么用 Transformer 解码 EEG
EEG 解码的传统方法是 LDA + CSP:手工提取特征,再喂给线性分类器。这条路成熟但天花板明显——特征工程依赖领域知识,换一个范式就得重来。
深度学习跳过了特征工程:CNN 直接从原始信号学特征。但 CNN 有自己的局限——它的卷积核是局部的,对 EEG 这种时序信号的长程依赖建模能力弱。
Transformer 的自注意力机制天然擅长捕捉长程依赖。NLP 领域已经证明了这一点。问题是:把 NLP 的 Transformer 搬到 EEG 上,哪些设计要改?
答案是几乎全要改。
EEG 不是 NLP:三个关键差异
| NLP | EEG | |
|---|---|---|
| 输入 | 离散 token(词表索引) | 连续多通道时序(float32 数组) |
| 位置信息 | 词序(离散位置) | 采样时间点(连续,长度可变) |
| 预测目标 | 下一个词(自回归生成) | 整段信号属于哪类(分类) |
这三个差异决定了架构的每一个选择。
架构设计
1 | 原始 EEG (n_channels, n_times) |
1. Token Embedding:不用词表,用卷积
NLP 有词表,每个词查表得到 embedding。EEG 没有词表——输入是 (n_channels, n_times) 的连续信号。
做法:用 Conv1d(n_channels, n_channels, kernel, stride) 把时域信号切成 token 序列。每个 token 覆盖 kernel 个时间点,步长 stride,输出的序列长度就是 n_tokens = (n_times - kernel) // stride + 1。
1 | class _TokenEmbedding(nn.Module): |
为什么用 Conv1d 而不是固定分帧:信号处理里"分帧"是按固定窗口切信号、每帧取均值或原始采样点——这是不可学习的。Conv1d 做的事本质相同(kernel=窗口,stride=步长),但权重可学习:kernel 内每个时间点的贡献由训练决定,比分帧平均更有表达力。
kernel/stride 怎么选:如果用户不指定,自动算。目标是让 n_tokens ≈ 128,50% 重叠(kernel = 2 * stride):
1 | stride = max(1, n_times // target_tokens) |
128 个 token 足够 Transformer 捕捉时序依赖,又不会让 O(N²) 的注意力计算爆炸。
2. RoPE:让位置编码外推
NLP 用固定的正弦位置编码或可学习的位置编码。问题:可学习位置编码只能编码训练时见过的长度,推理时遇到更长序列就废了。
RoPE(Rotary Positional Embedding)是更好的选择:它把位置信息编码成 Q/K 的旋转角度,天然支持长度外推——训练时见 128 个 token,推理时给 200 个也能跑。
1 | class _RotaryPositionalEmbedding(nn.Module): |
核心思想:位置 m 的 Q/K 不加偏置,而是旋转。_rotate_half 把特征维的相邻两个值当作一对,做 2D 旋转。内积 Q·K 自然包含相对位置信息 m-n,不需要显式的位置编码表。
3. 因果注意力:只看过去,不看未来
这是 GPT 风格 Transformer 的标志性设计。标准 BERT 式自注意力每个位置能看所有位置,但 EEG 解码如果用双向注意力,模型会"偷看"未来的脑电信号——这在实时 BCI 场景下是不可能的。
因果掩码:上三角为 -inf,softmax 后变成 0:
1 | class _CausalMask(nn.Module): |
在注意力分数上加这个掩码,位置 t 只能 attend 到 0…t。
4. Pre-LN + 残差:稳定训练
1 | class _DecoderBlock(nn.Module): |
Pre-LN(先 LayerNorm 再进子层)比 Post-LN(先子层再 LayerNorm)训练更稳定。这是 GPT-2 以后的标准做法。FFN 扩张比 4×,用 GELU 激活。
5. 分类头:取最后一个位置
BERT 取 [CLS] token 的输出做分类。GPT 没有 CLS,那取哪个位置?
取最后一个:x[:, -1, :]。因为因果注意力下,最后一个位置积累了前面所有位置的信息——它"看"了整段信号。
6. 长度自适应推理
推理时 EEG 长度可能和训练时不一样。因为用了 Conv1d token embedding + RoPE,模型天然支持任意长度输入(只要 n_times >= kernel)。如果推理时 token 数超过训练时,RoPE 的位置会外推,给一个警告:
1 | if n_tokens > self._train_n_tokens: |
7. 逐通道 z-score 归一化
EEG 通道间的幅度差异很大(额区 vs 枕区),不归一化的话梯度会被大信号通道主导。训练时计算每个通道的均值/标准差,推理时用同样的统计量:
1 | if self.normalize: |
均值和标准差随模型一起保存,推理时自动应用。
和 CNN 解码器的对比
同项目里还有一个 CNN 解码器作为对照:
1 | # CNN: 输入 (B, 1, n_ch, n_times),2D 卷积,全连接分类 |
| CNN | Transformer | |
|---|---|---|
| 输入视角 | 把 EEG 当"图像"(1×n_ch×n_times) | 把 EEG 当"时序"(n_tokens×d_model) |
| 感受野 | 局部(卷积核 3) | 全局(自注意力看所有 token) |
| 长度灵活性 | 固定(全连接层绑死 n_times) | 自适应(RoPE 外推) |
| 训练开销 | 小(参数少,全批量) | 大(O(N²) 注意力,需 mini-batch) |
CNN 快但天花板低,Transformer 慢但上限高。实际选择看数据量和计算资源。
训练细节
- 优化器:AdamW(lr=5e-4, weight_decay=1e-4)
- 损失:CrossEntropyLoss
- 不用 DataLoader——数据直接放 GPU,手动 shuffle index 切 mini-batch,省掉 dataloader 的开销
- 默认 50 epoch,batch_size=32
踩过的坑
- EEG 数据量太小:BCI 数据集通常几十到几百个 trial,Transformer 容易过拟合。数据增强(高斯噪声、时间平移)很重要
- RoPE 外推不是免费的:推理时 token 数超过训练时的 1.5 倍以上,准确率会明显下降。不是"支持任意长度"就真的随便用
- d_model 不要太大:EEG 通道数通常 32-64,d_model=64 已经够用。512 是 NLP 的配置,在 EEG 上只会过拟合
- 因果注意力不一定比双向好:离线分析场景下,双向注意力允许"偷看未来",准确率更高。因果注意力的价值在实时 BCI——只有实时场景才真正需要"不能看未来"的约束
总结
把 Transformer 从 NLP 搬到 EEG,不是换个输入格式就完了。Token 怎么构建、位置怎么编码、注意力要不要因果、分类头取哪个位置——每个设计选择都需要根据 EEG 的特性重新思考。
这个实现不是 SOTA(BCI 领域的 SOTA 是 EEGNet + 数据增强),但它完整地回答了一个问题:如果要从零实现一个 EEG Transformer,需要做哪些设计决策,以及为什么。
