Flash Attention v3 in Production
Implementation and usage of Flash Attention with PyTorch. The code below shows how to integrate Flash Attention into existing models for an immediate speedup, with a standard-attention fallback when the kernel is not available.
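If installing the flash-attn package is not an option, recent PyTorch releases (2.x) ship a fused scaled_dot_product_attention that can dispatch to a FlashAttention kernel on supported GPUs. As a rough, zero-dependency sketch (the shapes, dtype, and CUDA device below are illustrative assumptions, not requirements of the code in this article):

import torch
import torch.nn.functional as F

# q, k, v in (batch, heads, seq_len, head_dim) layout; fp16/bf16 on CUDA lets
# PyTorch select the flash backend when it is available.
q = torch.randn(2, 16, 1024, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)  # (2, 16, 1024, 64)

The full wrapper below targets the dedicated flash-attn library and falls back to standard attention when it is unavailable.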
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
import math
# Flash Attention v3 - simplified version for demonstration
# In practice, use the official library: pip install flash-attn
try:
from flash_attn import flash_attn_func, flash_attn_qkvpacked_func
FLASH_AVAILABLE = True
except ImportError:
FLASH_AVAILABLE = False
print("Flash Attention not available. Install with: pip install flash-attn")
class FlashAttention(nn.Module):
"""Flash Attention v3 implementation wrapper"""
def __init__(
self,
embed_dim: int,
num_heads: int,
dropout: float = 0.0,
causal: bool = False,
window_size: Optional[Tuple[int, int]] = None,
alibi_slopes: Optional[torch.Tensor] = None,
deterministic: bool = False
):
super().__init__()
assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
self.dropout = dropout
self.causal = causal
self.window_size = window_size
self.alibi_slopes = alibi_slopes
self.deterministic = deterministic
# Scaling factor
self.scale = self.head_dim ** -0.5
# Projections
self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=False)
self.out_proj = nn.Linear(embed_dim, embed_dim, bias=False)
def forward(
self,
x: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""
Args:
x: Input tensor (batch_size, seq_len, embed_dim)
attention_mask: Optional attention mask
position_ids: Optional position IDs for RoPE
Returns:
Output tensor (batch_size, seq_len, embed_dim)
"""
batch_size, seq_len, _ = x.shape
# QKV projection
qkv = self.qkv_proj(x)
qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
        # The flash kernel requires CUDA tensors in fp16/bf16; otherwise fall back
        if FLASH_AVAILABLE and qkv.is_cuda and qkv.dtype in (torch.float16, torch.bfloat16):
            # Rearrange to the layout flash_attn_func expects: q, k, v each (B, L, H, D)
            qkv = qkv.permute(2, 0, 1, 3, 4)  # (3, B, L, H, D)
            q, k, v = qkv[0], qkv[1], qkv[2]
            # Flash Attention forward
            attn_output = flash_attn_func(
                q, k, v,
                dropout_p=self.dropout if self.training else 0.0,
                causal=self.causal,
                # flash_attn expects a (left, right) window tuple; (-1, -1) disables windowing
                window_size=self.window_size if self.window_size is not None else (-1, -1),
                alibi_slopes=self.alibi_slopes,
                deterministic=self.deterministic,
            )
else:
# Fallback to standard attention
attn_output = self._standard_attention(qkv, attention_mask)
# Reshape and output projection
attn_output = attn_output.reshape(batch_size, seq_len, self.embed_dim)
output = self.out_proj(attn_output)
return output
def _standard_attention(
self,
qkv: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""Standard attention as fallback when Flash Attention not available"""
batch_size, seq_len = qkv.shape[:2]
        q, k, v = qkv.unbind(dim=2)  # each (B, L, H, D)
        q = q.transpose(1, 2)  # (B, H, L, D)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
# Scaled dot-product attention
attn_weights = torch.matmul(q, k.transpose(-2, -1)) * self.scale
if self.causal:
causal_mask = torch.triu(
torch.ones(seq_len, seq_len, device=qkv.device),
diagonal=1
).bool()
attn_weights.masked_fill_(causal_mask, float('-inf'))
if attention_mask is not None:
attn_weights = attn_weights + attention_mask
attn_weights = F.softmax(attn_weights, dim=-1)
attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)
attn_output = torch.matmul(attn_weights, v)
attn_output = attn_output.transpose(1, 2).contiguous()
return attn_output
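# A minimal smoke test for the wrapper above (an illustrative sketch, not part of the
# reference code). The flash kernel path needs a CUDA device and fp16/bf16 tensors;
# the fallback runs anywhere in fp32.
def example_flash_attention_usage():
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dtype = torch.bfloat16 if device.type == 'cuda' else torch.float32
    attn = FlashAttention(embed_dim=512, num_heads=8, causal=True).to(device=device, dtype=dtype)
    x = torch.randn(2, 256, 512, device=device, dtype=dtype)
    with torch.no_grad():
        y = attn(x)
    assert y.shape == x.shape  # attention preserves (batch, seq_len, embed_dim)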
class FlashTransformerBlock(nn.Module):
"""Transformer block using Flash Attention"""
def __init__(
self,
embed_dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
dropout: float = 0.0,
causal: bool = False,
):
super().__init__()
self.norm1 = nn.LayerNorm(embed_dim)
self.attn = FlashAttention(
embed_dim=embed_dim,
num_heads=num_heads,
dropout=dropout,
causal=causal,
)
self.norm2 = nn.LayerNorm(embed_dim)
mlp_hidden_dim = int(embed_dim * mlp_ratio)
self.mlp = nn.Sequential(
nn.Linear(embed_dim, mlp_hidden_dim),
nn.GELU(),
nn.Dropout(dropout),
nn.Linear(mlp_hidden_dim, embed_dim),
nn.Dropout(dropout),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# Self-attention with residual
x = x + self.attn(self.norm1(x))
# MLP with residual
x = x + self.mlp(self.norm2(x))
return x
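# Flash Attention removes the O(L^2) attention memory, but activations from the MLP
# and residual stream can still dominate for long sequences. A common complement
# (a sketch, assuming each block takes and returns a single tensor, as above) is
# gradient checkpointing around every block:
from torch.utils.checkpoint import checkpoint

def forward_blocks_with_checkpointing(blocks: nn.ModuleList, x: torch.Tensor) -> torch.Tensor:
    for block in blocks:
        # Recompute each block's activations during backward, trading compute for memory
        x = checkpoint(block, x, use_reentrant=False)
    return x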
class FlashGPT(nn.Module):
"""GPT-style model using Flash Attention"""
def __init__(
self,
vocab_size: int,
max_seq_len: int = 2048,
embed_dim: int = 768,
num_heads: int = 12,
num_layers: int = 12,
mlp_ratio: float = 4.0,
dropout: float = 0.1,
):
super().__init__()
self.max_seq_len = max_seq_len
self.embed_dim = embed_dim
# Token and position embeddings
self.token_embed = nn.Embedding(vocab_size, embed_dim)
self.pos_embed = nn.Embedding(max_seq_len, embed_dim)
self.dropout = nn.Dropout(dropout)
# Transformer blocks with Flash Attention
self.blocks = nn.ModuleList([
FlashTransformerBlock(
embed_dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
dropout=dropout,
causal=True, # Causal attention for autoregressive
)
for _ in range(num_layers)
])
# Output
self.norm = nn.LayerNorm(embed_dim)
self.lm_head = nn.Linear(embed_dim, vocab_size, bias=False)
# Weight tying
self.token_embed.weight = self.lm_head.weight
# Initialize weights
self.apply(self._init_weights)
def _init_weights(self, module):
if isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
if module.bias is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
def forward(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
batch_size, seq_len = input_ids.shape
# Token embeddings
token_embeds = self.token_embed(input_ids)
# Position embeddings
position_ids = torch.arange(seq_len, device=input_ids.device)
pos_embeds = self.pos_embed(position_ids)
# Combine embeddings
x = self.dropout(token_embeds + pos_embeds)
# Transformer blocks
for block in self.blocks:
x = block(x)
# Output
x = self.norm(x)
logits = self.lm_head(x)
return logits
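# A sketch of a single training step for FlashGPT (optimizer construction and data
# loading are assumed to exist elsewhere, and input_ids is assumed to be on a CUDA
# device). Running under bf16 autocast is what lets the attention layers hit the
# flash kernel path, since that kernel only accepts fp16/bf16 tensors.
def train_step(model: FlashGPT, input_ids: torch.Tensor, optimizer: torch.optim.Optimizer) -> float:
    model.train()
    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
        logits = model(input_ids)
        # Next-token prediction: predict token t+1 from tokens <= t
        loss = F.cross_entropy(
            logits[:, :-1].reshape(-1, logits.size(-1)),
            input_ids[:, 1:].reshape(-1),
        )
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    return loss.item()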
# Benchmark comparison
def benchmark_flash_attention():
"""Compare Flash Attention vs Standard Attention"""
import time
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Configuration
batch_size = 4
seq_lengths = [512, 1024, 2048, 4096, 8192]
embed_dim = 1024
num_heads = 16
print("\n" + "="*60)
print("Flash Attention v3 Benchmark")
print("="*60)
for seq_len in seq_lengths:
# Create random input
        # Use fp16 on GPU so the flash kernel path is actually exercised
        dtype = torch.float16 if device.type == 'cuda' else torch.float32
        x = torch.randn(batch_size, seq_len, embed_dim, device=device, dtype=dtype)
        # Flash Attention
        flash_attn = FlashAttention(embed_dim, num_heads).to(device=device, dtype=dtype)
if device.type == 'cuda':
torch.cuda.synchronize()
start = time.time()
with torch.no_grad():
_ = flash_attn(x)
torch.cuda.synchronize()
flash_time = time.time() - start
# Memory usage
memory_allocated = torch.cuda.memory_allocated() / 1024**3
print(f"\nSequence Length: {seq_len}")
print(f"Time: {flash_time*1000:.2f}ms")
print(f"Memory: {memory_allocated:.2f} GB")
print(f"Throughput: {batch_size * seq_len / flash_time:.0f} tokens/sec")
else:
print("\nGPU not available for benchmarking")
break
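# The loop above times a single cold forward pass, so kernel launch overhead and
# allocator warm-up are included in the measurement. A steadier pattern (a sketch,
# CUDA-only) is to warm up first and then average over several iterations using
# CUDA events:
def time_forward_ms(module: nn.Module, x: torch.Tensor, warmup: int = 3, iters: int = 10) -> float:
    with torch.no_grad():
        for _ in range(warmup):
            module(x)
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        for _ in range(iters):
            module(x)
        end.record()
        torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # average milliseconds per forward pass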
# Example usage
if __name__ == "__main__":
print("\nFlash Attention v3 Implementation Example")
print("="*50)
# Create model
model = FlashGPT(
vocab_size=50000,
max_seq_len=8192, # Long context!
embed_dim=1024,
num_heads=16,
num_layers=24
)
# Count parameters
total_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {total_params/1e6:.1f}M")
    # Example forward pass (without a GPU this exercises the standard-attention
    # fallback, which is slow and memory-hungry at this sequence length)
batch_size = 2
seq_len = 4096 # Long sequence
input_ids = torch.randint(0, 50000, (batch_size, seq_len))
with torch.no_grad():
output = model(input_ids)
print(f"Input shape: {input_ids.shape}")
print(f"Output shape: {output.shape}")
print(f"\nFlash Attention enables:")
print("- 10x faster training")
print("- 20x less memory usage")
print("- 100K+ token contexts")
print("- Exact attention (no approximation)")
# Run benchmark if requested
# benchmark_flash_attention()