2021 年,William Fedus、Barret Zoph、Noam Shazeer(Google)发表 Switch Transformer——MoE 谱系第三个里程碑。核心想法只有一个——把 MoE 路由从 top-k 简化到 top-1。Shazeer 2017 说"必须 $k > 1$ 才能有有意义的梯度"——Switch 论文反驳这个假设,证明 k=1 不仅可行而且更好。简化路由 + 选择性 bfloat16 精度 + 更小初始化 + 专家正则化——让稀疏模型规模到万亿参数,比 T5-XXL 快 4 倍,在 101 个语言上全面提升。蒸馏把万亿参数稀疏模型压缩 99%到 dense 模型,保留 30% 质量增益。
Kaplan 2020 的 scaling law 揭示了三个扩展轴——模型大小、数据集大小、计算预算。Switch Transformer 提出第四个轴——"在保持每样本 FLOPs 不变的同时增加参数数量"。假设是——参数数量本身就是独立的重要扩展维度。Switch 通过"稀疏激活的模型"实现——它高效利用 GPU/TPU 这种为稠密矩阵乘法设计的硬件。分布式训练里——稀疏激活的层把唯一权重分散到不同设备。模型权重随设备数增加,但每设备的内存和计算占用保持可管理。
2021 年 1 月(arXiv 提交日期),William Fedus、Barret Zoph、Noam Shazeer——同一个 Noam Shazeer,2017 年是Sparsely-Gated MoE的一作,也是Transformer的核心作者——在 Google 发表《Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity》。
论文摘要直接——
深度学习里,模型对所有输入复用同样参数。
MoE 模型违背这一点——
它给每个输入选不同参数——
结果是一个稀疏激活的、参数离谱地多的、计算成本恒定的模型。
但尽管 MoE 有几个显著成功,
广泛采用被复杂性、通信成本、训练不稳定性阻碍。
我们引入 Switch Transformer 解决这些。
Switch 的指导原则——"用简单和计算高效的方式最大化 Transformer 参数数"。
Kaplan 2020 的 scaling law 揭示了三个扩展轴——模型大小、数据集大小、计算预算。Switch 提出第四个轴——
我们探索第四个轴——
在保持每样本 FLOPs 恒定的同时
增加参数数量。
我们的假设是——
参数数量,独立于总计算——
是一个独立的重要扩展维度。
Switch 通过稀疏激活模型实现——"高效利用为稠密矩阵乘法设计的 GPU/TPU 硬件"。论文的核心贡献——
· Switch Transformer 架构——简化并改进 MoE
· vs T5 的扩展性质——7x+ 预训练加速,同样 FLOPS/token
· 蒸馏成功——99% 压缩 + 30% 质量增益保留
· 改进训练技巧——选择性 bfloat16 精度、更小初始化、更强专家正则
· 多语言全面提升——101 个语言,91% 的语言在 mT5 上有 4x+ 加速
· 万亿参数模型——T5-XXL 强基线上4 倍加速
Shazeer 2017 主张必须 $k > 1$——"学习路由如果没有比较至少两个专家的能力就不能工作"。Ramachandran-Le 2018 进一步研究 top-k——发现底层用更大 k 重要。Switch 反驳——"我们用一个简化策略,只路由到一个专家"。三大好处——(1) 路由计算减少;(2) 每个专家的批量至少减半;(3) 路由实现简化、通信成本降低。这是反直觉的洞察——MoE 越简单越好。
Shazeer 2017 的 MoE 路由——给定 token $x$,路由到 top-k 专家。router 变量 $W_r$ 产生 logits $h(x) = W_r \cdot x$,过 softmax 归一化——
其中 $T$ 是选中的 top-k 索引集合。Shazeer 2017 主张——"路由到 $k > 1$ 专家是必要的——为了有意义的梯度"。直觉——学习路由如果没有比较至少两个专家的能力就不能工作。Ramachandran-Le 2018 进一步研究——发现底层用更大 k 重要。
Switch 论文反驳这两个观察——
相反——我们用简化策略——
只路由到一个专家。
我们证明这个简化保留模型质量,
减少路由计算,表现更好。
这个 $k = 1$ 路由策略被称为 Switch 层。
"Switch" 这个名字的含义——路由器只是"开关"——把 token 切换到一个专家。
Switch 层的三个好处——
注意——MoE 和 Switch 路由都用 $p_i(x)$ 作为乘子(Eq. 2 里的 gate value)——让路由器可微。这是关键——梯度仍然能流过路由器。
这是 Switch 最美的地方——反直觉,但简单战胜复杂。MoE 越简单越好——这是 Switch 给出的哲学。
TPU 需要静态声明大小——所以 Switch 的张量形状编译时确定,但路由动态。专家容量解决这个矛盾——把 token 数平均分给专家,再乘容量因子(capacity factor)——比 1.0 大留缓冲。如果某专家收到的 token 超过容量——"dropped token"——计算被跳过,token 表示通过残差连接直接传到下一层。增大容量因子减少 token drop,但浪费计算和内存。Switch 论文发现——dropped token 率低于 1%,不依赖专家数量。
Switch 论文用 Mesh-TensorFlow——为分布式数据 + 模型并行设计的库。模型为 TPU 设计——需要静态声明大小。
但 Switch 路由动态——不同 batch 不同 token 会路由到不同专家。这两件事如何统一?
解法——专家容量——
专家容量是每个专家计算的 token 数——预先设定。容量因子大于 1.0 留缓冲,应付 token 分配不完美平衡。
如果一个专家收到的 token 超过容量——这些 token 被称为 "dropped tokens"——计算被跳过,token 表示通过残差连接直接传到下一层。
容量因子的权衡——
· 太小——很多 token 被 drop——模型质量下降
· 太大——浪费计算和内存(很多缓冲空槽)
Switch 论文发现——用负载平衡损失(下一章)配合足够高的系数,dropped token 率通常 < 1%,并且不依赖专家数——这意味着负载平衡有效。
Switch 简化了 Shazeer 2017 的"两个分离损失"(importance + load)——合成一个辅助损失。给定 $N$ 个专家、$T$ 个 token 的批量——损失是"实际分配比例 $f_i$"和"路由概率比例 $P_i$"的缩放点积。$f_i$——分到专家 $i$ 的 token 比例(不可微)。$P_i$——批内分给专家 $i$ 的路由概率比例(可微)。两者都希望 = $1/N$(均匀路由)。损失在均匀分布下最小化。$P$ 可微所以梯度能流回。系数 $\alpha = 10^{-2}$——"既保证平衡又不淹没主损失"。
Switch 简化了 Shazeer 2017 的原始设计——把负载平衡损失和importance weighting 损失合成一个辅助损失。
给定 $N$ 个专家(索引 $i = 1$ 到 $N$)、批量 $B$ 含 $T$ 个 token——辅助损失为缩放点积——
其中——
· $f_i$ 是实际分配到专家 $i$ 的 token 比例——
· $P_i$ 是批内分给专家 $i$ 的路由概率比例——
我们想要批内 token 均匀分到 $N$ 个专家——所以希望两个向量都 = $1/N$。Eq. 4 的辅助损失鼓励均匀路由——它在均匀分布下最小化。
注意——$P$ 可微,$f$ 不可微。但 $f \cdot P$ 通过 $P$ 可以求导——梯度能流回路由器。
损失乘以专家数 $N$ ——保持损失值随专家数变化而恒定。最后——超参 $\alpha$ 是辅助损失系数——Switch 论文用 $\alpha = 10^{-2}$——"足够大保证负载平衡,足够小不淹没主交叉熵目标"。
稀疏专家模型有训练不稳定性问题——硬切换决策让 softmax 计算对 bfloat16 精度敏感。GShard(Lepikhin 2020)不得不用 float32 训练——但通信成本巨大。Switch 论文展示——在模型的"局部区域选择性 cast 到 float32" 就能保持稳定性,不付出跨设备通信 float32 张量的代价。具体做法——把路由器输入 cast 到 float32,路由计算在 float32 里做,函数末尾把 dispatch/combine 张量 cast 回 bfloat16。结果——近乎 bfloat16 的速度 + float32 的训练稳定性。
稀疏专家模型可能引入普通 Transformer 没有的训练困难——硬切换路由决策每层都做,可能造成不稳定。低精度格式 bfloat16 在路由 softmax 计算里会加剧问题。
GShard(Lepikhin 2020)的解法——整个 MoE Transformer 都用 float32 训练。但 float32 比 bfloat16慢、占内存大、跨设备通信成本高。
Switch 的解法——选择性精度——
我们展示——
通过在模型的局部部分选择性 cast 到 float32 精度,
稳定性可以达成——
而不付出float32 张量跨设备通信的昂贵成本。
具体做法——把路由器输入 cast 到 float32。路由函数接收 token,产生"dispatch 和 combine 张量"用于专家计算的选择和重组。float32 精度只在路由函数内部使用——在本地设备上的计算。
因为函数末尾,dispatch 和 combine 张量被cast 回 bfloat16——不需要昂贵的 float32 张量跨设备通信。
结果——"接近 bfloat16 的速度 + float32 的训练稳定性"。
这是 Switch 的关键工程贡献——稀疏模型第一次可以稳定地用 bfloat16 训练。这解锁了万亿参数——因为float32 训练万亿参数太贵。
Switch 在 C4(Colossal Clean Crawled Corpus)上做 mask LM 预训练。Switch vs T5(dense)头对头比较——同样 FLOPs/token、同样硬件、同样训练步数。三大发现——(1) Switch 在速度-质量上同时打败 dense 和 MoE Transformer;(2) Switch 计算占用比 MoE 小;(3) Switch 在低容量因子(1.0, 1.25)下表现更好——更适合大模型内存稀缺的场景。Switch-Base 1.0 容量因子达到质量阈值用 62.8 小时,T5-Large 用 131.1 小时——2 倍快。在 mT5 baseline 上 91% 的 101 个语言得到4x+ 加速。
Switch Transformer 的第一个测试——在 C4(Colossal Clean Crawled Corpus) 上做mask LM 预训练。用负对数 perplexity作为指标。
Switch vs MoE Transformer 头对头比较 (Table 1)——
| 模型 | 容量因子 | 100k 步质量 ↑ | 到质量阈值 -1.50 时间 ↓ | 速度 |
|---|---|---|---|---|
| T5-Base | — | -1.731 | 未达到 | 1600 |
| T5-Large | — | -1.550 | 131.1 h | 470 |
| MoE-Base | 1.0 | -1.572 | 80.1 h | 860 |
| Switch-Base | 1.0 | -1.561 | 62.8 h | 1000 |
| Switch-Base+ | 1.0 | -1.534 | 67.6 h | 780 |
三大发现——
(1) Switch 在速度-质量上同时打败 dense 和 MoE——固定计算和时间,Switch 取得最好结果。
(2) Switch 计算占用比 MoE 小——把 Switch 扩大到匹配 MoE 训练速度时(Switch-Base+),每步基础上也击败所有 MoE 和 Dense。
(3) Switch 在低容量因子(1.0, 1.25)下更好——更适合大模型内存稀缺的场景。
论文还测了多语言学习——在 101 个语言上和 mT5-Base 对比。所有 101 个语言都得到普遍提升。91% 的语言在 mT5 baseline 上得到 4x+ 加速。
Switch 论文做了一件让我读完后惊讶的事——把稀疏预训练 + 专门化微调的模型成功蒸馏到小 dense 模型。"模型大小减少最多 99%,同时保留大稀疏教师 30% 的质量增益"。这是"训练用 MoE,推理用 dense"的有趣范式——训练时用稀疏的便宜容量,推理时用 dense 的部署简便。这给边缘部署、移动设备开了一条路。
Switch 论文做了一件让我特别注意的事——蒸馏。
大稀疏模型有部署难题——万亿参数不适合每台设备都装一份。蒸馏可以解决——
我们成功蒸馏稀疏预训练模型和专门化微调模型——
到小 dense 模型。
我们把模型大小最多减少 99%——
同时保留大稀疏教师 30% 的质量增益。
这是"训练用稀疏,推理用 dense"的新范式——
· 训练时——用 Switch 的便宜容量训出一个大稀疏模型
· 蒸馏时——把它的知识压缩到小 dense 模型
· 推理时——只部署 dense 模型——每台设备都能跑
"30% 质量增益保留"听起来不多——但对于 99% 压缩比来说是巨大的。例如——如果大稀疏模型比 dense baseline 提升 100%,那压缩后的小模型保留 30%——仍然显著好于原始 dense baseline。
这给边缘部署、移动设备、低延迟服务开了一条路。"训练用云,推理用本地"——这个范式后来被很多 LLM 工作借鉴。
论文最后展示万亿参数 Switch Transformer。结合数据并行 + 模型并行 + 专家并行三种并行——达到约 1 万亿参数。在 C4 上预训练——对比 T5-XXL 强基线(11 billion 参数 dense)——同时间到同质量阈值,Switch 快 4 倍。这是 MoE 谱系里的里程碑——稀疏激活模型第一次达到"参数千万亿级别"。从 1991 的几个专家到 2021 的万亿参数——30 年——MoE 完成了从"小众思想"到"前沿 LLM 必备"的完整转变。
论文最后一节展示万亿参数 Switch Transformer——"把神经语言模型的规模又推进一步"。
结合三种并行——
· 数据并行——多设备处理不同 batch
· 模型并行——一个层切到多设备
· 专家并行——不同专家放不同设备
组合让模型规模达到约 1 万亿参数。
实验——vs T5-XXL(11 billion 参数 dense)——在 C4 上预训练。结果——同时间到同质量阈值,Switch 比 T5-XXL 快 4 倍。
这是 MoE 谱系里的里程碑——稀疏激活模型第一次正式达到"参数万亿级"。
从1991 的几个专家到2017 的 137B 参数到2021 的万亿参数——30 年里 MoE 完成了从"小众思想"到"前沿 LLM 必备"的完整转变。
Switch 之后——2021 GLaM(Google 1.2T MoE LLM)、2023 GPT-4(被广泛认为是 MoE)、2024 Mixtral 8x7B(第一个流行开源 MoE)——每一个都建立在 Switch 奠定的工程范式上。
2021 年 1 月——
Switch Transformer 第一次达到万亿参数。
2024 年——
Mixtral 8x7B 把同样的范式开源给所有人。
2017 年的 Noam Shazeer——
2021 年的 Fedus + Zoph + Shazeer——
1991 年的 Jacobs + Jordan + Hinton——
34 年——同一个想法的进化。