Mamba in Action
A simplified implementation of the Selective State Space Model. This code shows how Mamba processes sequences with linear complexity while maintaining competitive performance.
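Before diving into the code, it helps to write out the recurrence that selective_scan below implements. In lightly simplified notation (this is the discretization the code uses, with an input-dependent step size \Delta_t):

\bar{A}_t = \exp(\Delta_t A), \qquad \bar{B}_t u_t = \Delta_t B_t u_t

h_t = \bar{A}_t h_{t-1} + \bar{B}_t u_t, \qquad y_t = C_t h_t + D u_t

Each step only touches a fixed-size state h_t, so both time and memory grow linearly with the sequence length.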
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from einops import rearrange, repeat
class MambaBlock(nn.Module):
"""Selective State Space Model (S4) Block"""
def __init__(self, d_model, d_state=16, d_conv=4, expand=2, dt_rank="auto"):
super().__init__()
self.d_model = d_model
self.d_state = d_state
self.d_conv = d_conv
self.expand = expand
self.d_inner = int(self.expand * self.d_model)
if dt_rank == "auto":
dt_rank = math.ceil(self.d_model / 16)
self.dt_rank = dt_rank
# Linear projections
self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=False)
# Convolution
self.conv1d = nn.Conv1d(
in_channels=self.d_inner,
out_channels=self.d_inner,
bias=True,
kernel_size=d_conv,
groups=self.d_inner,
padding=d_conv - 1
)
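        # groups=d_inner makes this a depthwise (per-channel) convolution; the
        # padding of d_conv - 1 is trimmed in forward(), so each position only
        # sees the current and previous inputs (the convolution stays causal).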
# SSM parameters
self.x_proj = nn.Linear(self.d_inner, self.dt_rank + self.d_state * 2, bias=False)
self.dt_proj = nn.Linear(self.dt_rank, self.d_inner, bias=True)
# Initialize special dt projection
dt_init_std = self.dt_rank**-0.5 * self.expand
nn.init.uniform_(self.dt_proj.weight, -dt_init_std, dt_init_std)
# S4D real initialization
A = repeat(
torch.arange(1, self.d_state + 1, dtype=torch.float32),
'n -> d n',
d=self.d_inner
).contiguous()
A_log = torch.log(A)
self.register_buffer("A_log", A_log)
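        # Storing log(A) and recovering A = -exp(A_log) in forward() keeps A strictly
        # negative, so exp(delta * A) stays in (0, 1) and the recurrence remains stable.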
# D parameter
self.D = nn.Parameter(torch.ones(self.d_inner))
# Output projection
self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=False)
def selective_scan(self, u, delta, A, B, C, D):
"""Performs the selective scan algorithm (S4)"""
batch, length, d_in = u.shape
d_state = A.shape[-1]
        # Discretize the continuous parameters (ZOH for A, simplified Euler step for B)
        deltaA = torch.exp(delta.unsqueeze(-1) * A)                          # (b, l, d_in, n)
        deltaB_u = delta.unsqueeze(-1) * B.unsqueeze(2) * u.unsqueeze(-1)    # (b, l, d_in, n)
        # Sequential scan over the time dimension
        x = torch.zeros((batch, d_in, d_state), device=u.device)
        ys = []
        for i in range(length):
            # State update: h_t = deltaA[:, i] * h_{t-1} + deltaB_u[:, i]
            x = deltaA[:, i] * x + deltaB_u[:, i]                 # (b, d_in, n)
            # Readout: y_t = C_t · h_t, summed over the state dimension
            # (unsqueeze broadcasts C over the channel dimension)
            y = (x * C[:, i].unsqueeze(1)).sum(dim=-1)            # (b, d_in)
            ys.append(y)
y = torch.stack(ys, dim=1)
y = y + D * u
return y
def forward(self, x):
"""Mamba block forward pass
Args:
x: (batch, length, d_model)
Returns:
output: (batch, length, d_model)
"""
batch, length, _ = x.shape
# Linear projection and split
x_and_res = self.in_proj(x)
x, res = x_and_res.chunk(2, dim=-1)
# Convolution
x = rearrange(x, 'b l d -> b d l')
x = self.conv1d(x)[:, :, :length]
x = rearrange(x, 'b d l -> b l d')
# Non-linearity
x = F.silu(x)
# SSM
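        # The "selective" part: delta, B and C are computed from the input x itself,
        # so how much of the state is kept or overwritten depends on each token.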
x_proj = self.x_proj(x)
delta, B, C = x_proj.split([self.dt_rank, self.d_state, self.d_state], dim=-1)
delta = F.softplus(self.dt_proj(delta))
# Get A from buffer
A = -torch.exp(self.A_log)
# Selective scan
y = self.selective_scan(x, delta, A, B, C, self.D)
# Gating
y = y * F.silu(res)
# Output projection
output = self.out_proj(y)
return output
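
# Illustrative shape check for a single block (not part of the model definition):
#   block = MambaBlock(d_model=64)
#   y = block(torch.randn(2, 128, 64))   # y.shape == (2, 128, 64)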
class Mamba(nn.Module):
"""Full Mamba model for sequence modeling"""
def __init__(self, vocab_size, d_model=768, n_layers=24, d_state=16):
super().__init__()
self.d_model = d_model
# Token embeddings
self.embedding = nn.Embedding(vocab_size, d_model)
# Stack of Mamba blocks
self.layers = nn.ModuleList([
MambaBlock(d_model, d_state=d_state)
for _ in range(n_layers)
])
# Final norm and output
self.norm = nn.LayerNorm(d_model)
self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
# Tie weights
self.lm_head.weight = self.embedding.weight
def forward(self, input_ids):
"""
Args:
input_ids: (batch, length) tensor of token indices
Returns:
logits: (batch, length, vocab_size) tensor of predictions
"""
# Embed tokens
x = self.embedding(input_ids)
# Apply Mamba blocks
for layer in self.layers:
            x = x + layer(x)  # residual connection (the official model also applies RMSNorm before each block)
# Final norm and prediction
x = self.norm(x)
logits = self.lm_head(x)
return logits
# Usage example
def demo():
    # Configuration
    vocab_size = 50000
    batch_size = 2
    seq_length = 10000  # very long sequence (note: the Python scan loop above favors clarity over speed)
    # Build the model
    model = Mamba(vocab_size=vocab_size, d_model=768, n_layers=24)
    # Example data
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_length))
    # Forward pass
    with torch.no_grad():
        output = model(input_ids)
    print(f"Input shape: {input_ids.shape}")
    print(f"Output shape: {output.shape}")
    print("Memory usage grows linearly with the sequence length")
    print("The recurrent state per step is constant, which is what makes very long sequences feasible")
if __name__ == "__main__":
demo()
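
To make the linear-memory claim concrete, here is a back-of-envelope sketch (demo_state_size is not part of the listing above; the numbers mirror the configuration used in demo()). The only state that selective_scan carries from one time step to the next is a tensor of shape (batch, d_inner, d_state), whose size does not depend on the sequence length.

def demo_state_size():
    # Hypothetical helper: size of the recurrent state carried across time steps.
    d_model, d_state, expand, n_layers = 768, 16, 2, 24
    d_inner = expand * d_model                       # 1536, as in MambaBlock
    floats_per_layer = d_inner * d_state             # state is (d_inner, d_state) per sequence
    total_floats = n_layers * floats_per_layer
    print(f"Recurrent state per layer: {floats_per_layer:,} floats")
    print(f"Across {n_layers} layers: {total_floats:,} floats "
          f"(~{total_floats * 4 / 1e6:.1f} MB in fp32), independent of sequence length")

# demo_state_size() prints: 24,576 floats per layer, 589,824 in total (~2.4 MB in fp32)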