MoE in PyTorch
A complete Mixture of Experts implementation in PyTorch, including a router with top-k selection, multiple expert networks, a load-balancing loss, and a full Mixtral-style block. It also includes a Switch Transformer layer (one expert per token, with a capacity limit).
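The core idea, as implemented below: for each token x the router scores all experts, keeps only the top-k gate values and renormalizes them to sum to 1, and the layer output is the gate-weighted sum y = Σ g_i(x) · E_i(x) over the k selected experts, so only k expert FFNs run per token even though the layer stores num_experts of them.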
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple
class Expert(nn.Module):
    """A single expert: a specialized feed-forward network (FFN)."""
    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff)
        self.w2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(0.1)

    def forward(self, x):
        # FFN: x -> ReLU(xW1) -> Dropout -> W2
        return self.w2(self.dropout(F.relu(self.w1(x))))
class TopKRouter(nn.Module):
    """Router that selects the top-k experts for each token."""
    def __init__(self, d_model: int, num_experts: int, k: int = 2):
        super().__init__()
        self.k = k
        self.num_experts = num_experts
        # Router network: a single linear projection to expert logits
        self.router = nn.Linear(d_model, num_experts)
        # Optional Gaussian noise on the logits (encourages load balancing)
        self.noise_std = 1e-2

    def forward(self, x: torch.Tensor, training: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            x: [batch_size, seq_len, d_model]
        Returns:
            gates: [batch_size, seq_len, k] - expert weights
            indices: [batch_size, seq_len, k] - expert indices
        """
        # Router logits
        logits = self.router(x)  # [B, L, num_experts]
        # Add noise during training (helps load balancing)
        if training:
            noise = torch.randn_like(logits) * self.noise_std
            logits = logits + noise
        # Softmax over experts
        probs = F.softmax(logits, dim=-1)
        # Select the top-k experts
        gates, indices = torch.topk(probs, self.k, dim=-1)
        # Renormalize the gates so they sum to 1 over the selected experts
        gates = gates / gates.sum(dim=-1, keepdim=True)
        return gates, indices

    def compute_load_balancing_loss(self, logits: torch.Tensor):
        """Auxiliary loss that pushes the router toward a balanced load across experts."""
        # Target: a uniform distribution of router probability mass over experts
        probs = F.softmax(logits, dim=-1)
        # Total router probability assigned to each expert
        expert_counts = probs.sum(dim=[0, 1])  # [num_experts]
        expert_fracs = expert_counts / expert_counts.sum()
        # Penalize deviation from the uniform distribution
        target = 1.0 / self.num_experts
        load_loss = ((expert_fracs - target) ** 2).sum()
        return load_loss
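
# Side note (illustrative sketch, not used by the classes here): the auxiliary
# loss proposed in the Switch Transformer paper is usually written as
# num_experts * sum_i(f_i * P_i), where f_i is the fraction of tokens whose
# top-1 choice is expert i and P_i is the mean router probability for expert i.
def switch_style_load_balancing_loss(logits: torch.Tensor, indices: torch.Tensor,
                                     num_experts: int) -> torch.Tensor:
    probs = F.softmax(logits, dim=-1)                # [B, L, num_experts]
    top1 = indices[..., 0] if indices.dim() == 3 else indices
    dispatch = F.one_hot(top1, num_experts).float()  # [B, L, num_experts]
    f = dispatch.mean(dim=(0, 1))                    # fraction of tokens per expert
    P = probs.mean(dim=(0, 1))                       # mean router probability per expert
    return num_experts * (f * P).sum()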
class MixtureOfExperts(nn.Module):
    """Complete MoE layer, Mixtral-style (top-k routing over a pool of experts)."""
    def __init__(
        self,
        d_model: int = 4096,
        d_ff: int = 14336,
        num_experts: int = 8,
        k: int = 2,  # active experts per token
    ):
        super().__init__()
        self.num_experts = num_experts
        self.k = k
        # Create the experts
        self.experts = nn.ModuleList([
            Expert(d_model, d_ff) for _ in range(num_experts)
        ])
        # Router
        self.router = TopKRouter(d_model, num_experts, k)

    def forward(self, x: torch.Tensor):
        """
        Args:
            x: [batch_size, seq_len, d_model]
        Returns:
            output: [batch_size, seq_len, d_model]
            aux_loss: scalar load-balancing loss
        """
        batch_size, seq_len, d_model = x.shape
        # The router selects the experts for each token
        gates, indices = self.router(x, training=self.training)
        # gates: [B, L, k], indices: [B, L, k]
        # Flatten tokens for dispatch
        x_flat = x.view(-1, d_model)             # [B*L, d_model]
        gates_flat = gates.view(-1, self.k)      # [B*L, k]
        indices_flat = indices.view(-1, self.k)  # [B*L, k]
        # Accumulator for the combined expert outputs
        output_flat = torch.zeros_like(x_flat)
        # Process tokens expert by expert (readable reference loop, not optimized)
        for expert_idx in range(self.num_experts):
            # Tokens routed to this expert in any of their k slots
            expert_mask = (indices_flat == expert_idx)
            if not expert_mask.any():
                continue
            # Positions and inputs of the tokens handled by this expert
            token_indices = expert_mask.any(dim=1).nonzero(as_tuple=True)[0]
            expert_input = x_flat[token_indices]
            # Run the expert once over all of its tokens
            expert_output = self.experts[expert_idx](expert_input)
            # Apply the gates and accumulate, slot by slot
            for k_idx in range(self.k):
                mask_k = expert_mask[:, k_idx]
                if mask_k.any():
                    gate_weights = gates_flat[mask_k, k_idx].unsqueeze(1)
                    output_flat[mask_k] += gate_weights * expert_output[mask_k[token_indices]]
        # Reshape back to [B, L, d_model]
        output = output_flat.view(batch_size, seq_len, d_model)
        # Load-balancing loss (recomputes the router logits, without noise)
        aux_loss = self.router.compute_load_balancing_loss(
            self.router.router(x)
        )
        return output, aux_loss
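
# Illustrative helper (an assumption, not part of the layer above): counts total
# vs. per-token-active parameters of a MixtureOfExperts layer. Parameter count
# grows with num_experts, while per-token compute grows only with k.
def moe_param_counts(moe: MixtureOfExperts):
    per_expert = sum(p.numel() for p in moe.experts[0].parameters())
    router_params = sum(p.numel() for p in moe.router.parameters())
    total = per_expert * moe.num_experts + router_params
    active = per_expert * moe.k + router_params
    return total, active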
class SwitchTransformer(nn.Module):
    """Switch Transformer layer: exactly one expert per token, with a capacity limit."""
    def __init__(self, d_model: int, num_experts: int, capacity_factor: float = 1.25):
        super().__init__()
        self.num_experts = num_experts
        self.capacity_factor = capacity_factor
        self.experts = nn.ModuleList([
            Expert(d_model, d_model * 4) for _ in range(num_experts)
        ])
        self.router = nn.Linear(d_model, num_experts)

    def forward(self, x: torch.Tensor):
        batch_size, seq_len, d_model = x.shape
        # The router picks a single expert for each token
        logits = self.router(x)
        expert_indices = logits.argmax(dim=-1)  # [B, L]
        expert_probs = F.softmax(logits, dim=-1)
        # Capacity: maximum number of tokens per expert, over the whole batch
        capacity = int(self.capacity_factor * batch_size * seq_len / self.num_experts)
        output = torch.zeros_like(x)
        # Process each expert, enforcing its capacity
        for expert_idx in range(self.num_experts):
            mask = (expert_indices == expert_idx)
            token_indices = mask.nonzero(as_tuple=True)
            # Enforce the capacity limit
            if len(token_indices[0]) > capacity:
                # Keep the tokens with the highest router probability
                token_probs = expert_probs[mask][:, expert_idx]
                top_indices = token_probs.topk(capacity).indices
                token_indices = (
                    token_indices[0][top_indices],
                    token_indices[1][top_indices]
                )
            if len(token_indices[0]) > 0:
                expert_input = x[token_indices]
                expert_output = self.experts[expert_idx](expert_input)
                # Scale by the router probability so gradients reach the router
                gate = expert_probs[token_indices][:, expert_idx].unsqueeze(-1)
                output[token_indices] = gate * expert_output
        return output
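
# Worked example of the capacity limit: with capacity_factor=1.25 and the demo
# at the bottom of this file (batch_size=2, seq_len=128, num_experts=64), each
# expert handles at most int(1.25 * 2 * 128 / 64) = 5 tokens per forward pass.
# Tokens routed beyond that limit are dropped by this layer (inside a full
# transformer block they would still pass through the residual connection).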
# Complete Mixtral-style model
class MixtralBlock(nn.Module):
    """Transformer block with an MoE FFN."""
    def __init__(self, d_model=4096, n_heads=32, num_experts=8, k=2):
        super().__init__()
        # Self-attention (dense)
        self.attn = nn.MultiheadAttention(
            d_model, n_heads, batch_first=True
        )
        # MoE FFN (sparse)
        self.moe = MixtureOfExperts(d_model, d_model * 4, num_experts, k)
        # Layer norms
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

    def forward(self, x):
        # Self-attention with residual connection
        attn_out, _ = self.attn(x, x, x)
        x = x + attn_out
        x = self.ln1(x)
        # MoE FFN with residual connection
        moe_out, aux_loss = self.moe(x)
        x = x + moe_out
        x = self.ln2(x)
        return x, aux_loss
# Usage
if __name__ == "__main__":
    # Mixtral-style MoE
    batch_size, seq_len, d_model = 2, 512, 4096
    x = torch.randn(batch_size, seq_len, d_model)
    moe_layer = MixtureOfExperts(
        d_model=4096,
        d_ff=14336,
        num_experts=8,
        k=2  # 2 active experts per token
    )
    output, aux_loss = moe_layer(x)
    print(f"Input: {x.shape}")
    print(f"Output: {output.shape}")
    print(f"Load balancing loss: {aux_loss.item():.4f}")
    # Approximate FFN FLOPs: a dense model with the same total parameters
    # (all 8 experts applied to every token) vs. the sparse MoE (only k = 2)
    flops_per_expert = batch_size * seq_len * d_model * 14336 * 2
    dense_flops = flops_per_expert * 8   # every expert runs on every token
    sparse_flops = flops_per_expert * 2  # only k experts run per token
    print(f"\nFLOPs reduction: {dense_flops / sparse_flops:.1f}x")

    # Switch Transformer
    switch = SwitchTransformer(d_model=512, num_experts=64)
    x_small = torch.randn(2, 128, 512)
    output_switch = switch(x_small)
    print(f"\nSwitch output: {output_switch.shape}")