LLM面面观之MoE
1. 背景
根据本qiang~最新的趋势观察,基于MoE架构的开源大模型越来越多,比如马斯克的Grok-1(314B), Qwen1.5-MoE-A2.7B等,因此想探究一下MoE里面的部分细节。
此文是本qiang~
针对大语言模型的MoE的整理,包括原理、流程及部分源码
。
2. MoE原理
MoE的流行源于”欧洲的OpenAI” Mistral AI发布的论文及模型《Mixtral of Experts》,评测集上的效果吊打众多开源模型,如Llama 2 70B和GPT3.5。
《Mixtral of Experts》基础模型使用的是Mistral AI自研的Mistral 7B,该模型的特点包括:滑窗注意力(Sliding Window Aattention), 滚动缓冲区缓存(Rolling Buffer Cache)以及预填充-分块(Pre-fill and Chunking),具体细节可以查阅文末的论文地址。
本文以《Mixtral of Experts》为引子,探究MoE的相关细节,MoE的原理如下图所示:
图2.1 MoE的原理
(1) Transformers架构中的每一层中的FFN网络均替换为了8个FFN(专家),且由一个网关路由(gate router)进行控制
(2) 针对每一个token,每一层的网关路由仅选择其中的2个FFN(专家)来处理当前状态并进行加权输出
(3) 结果就是,每一个token访问了47B参数,但是在推理阶段仅仅使用了13B的激活参数(即,只使用2个专家,冻结其他6个专家)。
(4) 与Dropout机制对比,Dropout让部分神经元失活,而MoE是让部分专家失活。
3. 源码
本qiang~研读并尝试执行了Mistral官网的github推理代码,该代码框架非常适合新手,无他,只因其几乎只是在torch上层做的封装,很少引擎其他第三方库,不像transformers,功能强大,但不适合新手研读代码…
为了普适性,下面的代码截取了transformers框架中的代码。
首先看下通用Transformers中FFN中的代码模块,代码位置在transformers.models.mistral.modeling_mistral, 主要流程是:
(1) 先经过gate_proj和up_proj的2个[hidden_size, intermediate_size]的线性转换
(2) 使用激活函数对gate_proj进行激活
(3) 二者的内积再经过down_proj线性转换。
1 classMistralMLP(nn.Module):2 def __init__(self, config):3 super().__init__()4 self.config =config5 self.hidden_size =config.hidden_size6 self.intermediate_size =config.intermediate_size7 self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)8 self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)9 self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)10 self.act_fn =ACT2FN[config.hidden_act]11 12 defforward(self, x):13 return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
再来看下MoE中的专家模块,代码位置在transformers.models.mixtral.modeling_mixtral,主要流程是:
(1) 首先经过网关路由self.gate
(2) 然后选择其中2个专家,并归一化
(3) 之后遍历每个专家网络,并按照expert_mask进行筛选
(4) 如果expert_mask有值,则选择指定部分的隐藏层进行FFN操作,且输出结果进行加权
(5) 最后原地增加先前初始化的最终结果变量final_hidden_states
classMixtralSparseMoeBlock(nn.Module):def __init__(self, config):
super().__init__()
self.hidden_dim=config.hidden_size
self.ffn_dim=config.intermediate_size
self.num_experts=config.num_local_experts
self.top_k=config.num_experts_per_tok#gating self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)
self.experts= nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ inrange(self.num_experts)])def forward(self, hidden_states: torch.Tensor) ->torch.Tensor:""" """batch_size, sequence_length, hidden_dim=hidden_states.shape
hidden_states= hidden_states.view(-1, hidden_dim)#router_logits: (batch * sequence_length, n_experts) router_logits =self.gate(hidden_states)
routing_weights= F.softmax(router_logits, dim=1, dtype=torch.float)
routing_weights, selected_experts= torch.topk(routing_weights, self.top_k, dim=-1)
routing_weights/= routing_weights.sum(dim=-1, keepdim=True)#we cast back to the input dtype routing_weights =routing_weights.to(hidden_states.dtype)
final_hidden_states=torch.zeros(
(batch_size* sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
)#One hot encode the selected experts to create an expert mask #this will be used to easily index which expert is going to be sollicitated expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)#Loop over all available experts in the model and perform the computation on each expert for expert_idx inrange(self.num_experts):
expert_layer=self.experts[expert_idx]
idx, top_x=torch.where(expert_mask[expert_idx])if top_x.shape[0] ==0:continue #in torch it is faster to index using lists than torch tensors top_x_list =top_x.tolist()
idx_list=idx.tolist()#Index the correct hidden states and compute the expert hidden state for #the current expert. We need to make sure to multiply the output hidden #states by `routing_weights` on the corresponding tokens (top-1 and top-2) current_state = hidden_states[None, top_x_list].reshape(-1, hidden_dim)
current_hidden_states= expert_layer(current_state) *routing_weights[top_x_list, idx_list, None]#However `index_add_` only support torch tensors for indexing so we'll use #the `top_x` tensor here. final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
final_hidden_states=final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)return final_hidden_states, router_logits
其中MixtralBlockSparseTop2MLP代码如下,可以看到和传统MistralMLP内容完全一致。
classMixtralBlockSparseTop2MLP(nn.Module):def __init__(self, config: MixtralConfig):
super().__init__()
self.ffn_dim=config.intermediate_size
self.hidden_dim=config.hidden_size
self.w1= nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
self.w2= nn.Linear(self.ffn_dim, self.hidden_dim, bias=False)
self.w3= nn.Linear(self.hidden_dim, self.ffn_dim, bias=False)
self.act_fn=ACT2FN[config.hidden_act]defforward(self, hidden_states):
current_hidden_states= self.act_fn(self.w1(hidden_states)) *self.w3(hidden_states)
current_hidden_states=self.w2(current_hidden_states)return current_hidden_states
4. MoE微调
由于MoE只是将每一层的FFN改变为了每一层的gate网关路由+8个FFN专家,且gate网关路由和8个专家内部均为线性运算,所以可以无缝地结合LoRA、QLoRA进行指令微调。
可以参考开源项目:https://github.com/yangjianxin1/Firefly
5. 答疑解惑
(1) 问:MoE 8*7B的模型是56B参数?
答:MoE 8*7B的参数量是47B,而不是56B,原因是每一层除了8个专家网络外,其他层均是复用的。
(2) 问:MoE的基础模型是Mistral 7B?
答:不是,MoE的模型架构与Mistral 7B相同,但其中的FFN替换为了8个FFN,且MoE是基于多语言数据集预训练而来的。
(3) MoE的稀疏性(sparse)体现在哪里?
答:在训练和推理时,同时只有两个专家网络会被激活,进行前向计算,其它专家网络处于失活状态。
6. 总结
一句话足矣~
本文主要针对大语言模型的MoE,包括原理及部分源码。
此外,建议大家可以针对源码进行运行,关于源码,欢迎大家一块交流。
7. 参考
(1) Mistral 7B:
https://arxiv.org/pdf/2310.06825v1.pdf
(2) MoE:
https://arxiv.org/pdf/2401.04088v1.pdf
(3) MoE开源指令微调框架Firefly:
https://github.com/yangjianxin1/Firefly