---
id: wiki-2026-0508-flash-attention
title: Flash Attention
category: 10_Wiki/Topics
status: verified
canonical_id: self
aliases: [FlashAttention, FA2, FA3, IO-aware attention, Tri Dao, online softmax]
duplicate_of: none
source_trust_level: A
confidence_score: 0.98
verification_status: applied
tags: [transformer, attention, gpu, optimization, flash-attention, memory-efficient]
raw_sources: []
last_reinforced: 2026-05-10
github_commit: pending
tech_stack:
  language: CUDA / PyTorch
  framework: flash-attn / xformers / vLLM
---
# Flash Attention
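## One-line summary
> **"An IO-aware, tile-based, exact attention algorithm."** Introduced by Tri Dao et al. in 2022 (FA1), with follow-ups in 2023 (FA2) and 2024 (FA3). It fixes the quadratic memory cost of attention, reducing extra memory from O(N²) to O(N) in sequence length. It is the standard attention kernel for modern transformers and is available through vLLM, xformers, and native PyTorch.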
## Core ideas
### Problem (vanilla attention)
- **Standard attention**: O(N²) memory, because the full N×N attention matrix is materialized.
- **HBM bandwidth**: memory traffic, not FLOPs, is the bottleneck.
- **Long context**: runs out of memory (OOM).
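To make the quadratic cost concrete, the sketch below estimates how many bytes the materialized score matrix needs. The head count, batch size, and fp16 element size are illustrative assumptions, not values from this note.

```python
def attention_matrix_bytes(seq_len, n_heads=32, batch=1, bytes_per_elem=2):
    """Bytes needed to materialize the full N x N score matrix (fp16 by default)."""
    return batch * n_heads * seq_len * seq_len * bytes_per_elem


# Hypothetical sequence lengths, chosen only to show the growth rate
for n in (2_048, 32_768, 131_072):
    gib = attention_matrix_bytes(n) / 2**30
    print(f"N={n:>7}: {gib:,.2f} GiB")
```

Under these assumptions a 32,768-token sequence already needs 64 GiB for the score matrix of a single layer, which is why long contexts hit OOM before compute becomes the limit.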
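### Solution (Flash)
- **Tile** Q, K, V into blocks.
- **Online softmax**: computed incrementally, so the full matrix is never stored.
- **SRAM compute**: each tile is processed in fast on-chip memory.
- **Recomputation**: the backward pass recomputes attention blocks instead of storing them, trading compute for memory.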
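The online softmax step can be illustrated on a single row of scores. This is a minimal pure-Python sketch, not the kernel itself: the function names and the block size are hypothetical, and the final list comprehension exists only so the result can be compared with an ordinary softmax (the real kernel rescales its output accumulator instead of revisiting the scores).

```python
import math


def online_softmax(scores, block_size=4):
    """Softmax over a list of floats, consumed one block at a time.

    Only a running max `m` and a running denominator `l` are carried between
    blocks, which is the bookkeeping the tiled algorithm relies on.
    """
    m = float("-inf")  # running maximum seen so far
    l = 0.0            # running sum of exp(score - m)
    for start in range(0, len(scores), block_size):
        block = scores[start:start + block_size]
        m_new = max(m, max(block))
        # Rescale the old denominator to the new max, then add this block
        l = l * math.exp(m - m_new) + sum(math.exp(s - m_new) for s in block)
        m = m_new
    # Final pass only for checking against the reference implementation
    return [math.exp(s - m) / l for s in scores]


def reference_softmax(scores):
    """Ordinary two-pass softmax, used as ground truth."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]
```

Both functions should agree up to floating-point rounding for any block size, which is what makes the tiled algorithm exact rather than approximate.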
### Versions
- **FA1** (2022): the baseline algorithm.
- **FA2** (2023): better parallelism, about 2x faster than FA1.
- **FA3** (2024): optimized for H100 (Hopper), asynchronous execution.
### Applications
1. **All transformer training**.
2. **Long-context** (100K+ tokens).
3. **Inference** (vLLM, TGI).
4. **Multi-query / GQA**.
5. **Sparse / sliding window**.
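## 💻 Patterns
### PyTorch native (FA built-in)
```python
import torch
import torch.nn.functional as F

# Example inputs, shape (batch, n_heads, seq_len, head_dim); the Flash backend needs fp16/bf16 on CUDA
q, k, v = (torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16) for _ in range(3))
# PyTorch 2.0+ scaled_dot_product_attention dispatches to a Flash kernel when the inputs are eligible
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
```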
### flash-attn (Tri Dao package)
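```python
from flash_attn import flash_attn_func, flash_attn_varlen_func

# Standard: q, k, v have shape (batch, seq_len, n_heads, head_dim), fp16/bf16 on CUDA
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
# Variable length (no padding waste): sequences are packed as (total_tokens, n_heads, head_dim)
# and cu_seqlens_* are int32 cumulative sequence offsets of shape (batch + 1,)
out = flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, causal=True)
```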
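The `cu_seqlens_q` / `cu_seqlens_k` arguments are cumulative offsets into the packed token dimension. A minimal sketch of how they could be derived from per-sequence lengths follows; `build_cu_seqlens` is a hypothetical helper name, and the lengths are made up for illustration.

```python
from itertools import accumulate


def build_cu_seqlens(seq_lens):
    """Cumulative sequence offsets: [0, len0, len0 + len1, ...]."""
    return [0] + list(accumulate(seq_lens))


seq_lens = [5, 3, 8]                     # three sequences packed without padding
cu_seqlens = build_cu_seqlens(seq_lens)  # sequence i occupies rows cu[i]:cu[i + 1]
max_seqlen = max(seq_lens)
```

Before calling the kernel, the list would typically be converted to a tensor on the GPU, for example with `torch.tensor(cu_seqlens, dtype=torch.int32, device="cuda")`.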
### xformers
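```python
from xformers.ops import LowerTriangularMask, memory_efficient_attention

# q, k, v: (batch, seq_len, n_heads, head_dim); LowerTriangularMask() is the causal bias
out = memory_efficient_attention(q, k, v, attn_bias=LowerTriangularMask())
```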
### vLLM (paged attention serving)
```python
from vllm import LLM, SamplingParams

llm = LLM(model='meta-llama/Meta-Llama-3-8B', dtype='bfloat16')
# vLLM manages the KV cache with PagedAttention and picks an attention backend (FlashAttention where supported)
outputs = llm.generate(['Hello'], SamplingParams(max_tokens=100))
```
### Manual Flash-style (educational, simplified)
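```python
import math
import torch

def flash_attention_simple(Q, K, V, block_size=64):
    """Simplified single-head sketch of the tiled forward pass.

    Q, K, V: (batch, seq_len, head_dim). Non-causal, no dropout.
    Real implementations are fused CUDA kernels.
    """
    N, d = Q.shape[1], Q.shape[2]
    sm_scale = 1.0 / math.sqrt(d)
    O = torch.zeros_like(Q)                                                      # unnormalized output accumulator
    L = torch.zeros(Q.shape[:2], dtype=Q.dtype, device=Q.device)                 # running softmax denominator
    M = torch.full(Q.shape[:2], float('-inf'), dtype=Q.dtype, device=Q.device)   # running row max
    for j in range(0, N, block_size):
        Kj = K[:, j:j+block_size]
        Vj = V[:, j:j+block_size]
        for i in range(0, N, block_size):
            Qi = Q[:, i:i+block_size]
            Sij = (Qi @ Kj.transpose(-1, -2)) * sm_scale
            Mij = Sij.max(dim=-1, keepdim=True).values
            Mi_old = M[:, i:i+block_size, None]
            Mi_new = torch.maximum(Mi_old, Mij)
            Pij = torch.exp(Sij - Mi_new)
            # Online normalization: rescale earlier partial sums to the new running max
            scale = torch.exp(Mi_old - Mi_new)
            O[:, i:i+block_size] = O[:, i:i+block_size] * scale + Pij @ Vj
            L[:, i:i+block_size] = L[:, i:i+block_size] * scale.squeeze(-1) + Pij.sum(dim=-1)
            M[:, i:i+block_size] = Mi_new.squeeze(-1)
    return O / L[..., None]  # normalize once at the end
```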
### Sliding window (Mistral-style)
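```python
from flash_attn import flash_attn_func

# window_size=(left, right): with causal=True each query sees itself plus the previous `window_left` keys
out = flash_attn_func(q, k, v, window_size=(window_left, 0), causal=True)
```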
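To see what a `(window_left, 0)` window means per token, the sketch below lists the key positions one query may attend to. It assumes equal query and key lengths and inclusive window bounds; `visible_keys` is a hypothetical helper for illustration, not part of flash-attn.

```python
def visible_keys(query_pos, seq_len, window_left, window_right=0):
    """Key positions a query may attend to under a local window (inclusive bounds)."""
    lo = max(0, query_pos - window_left)
    hi = min(seq_len - 1, query_pos + window_right)
    return list(range(lo, hi + 1))


# With window_left=2 and no lookahead, token 5 sees tokens 3, 4 and 5
print(visible_keys(5, seq_len=10, window_left=2))
```

Because each query touches at most `window_left + 1` keys, the work per token stays constant as the sequence grows, which is the property Mistral-style models rely on.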
### Grouped Query Attention (GQA)
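```python
import torch.nn as nn
from flash_attn import flash_attn_func

class GQA(nn.Module):
    def __init__(self, dim, n_heads, n_kv_heads):
        super().__init__()
        assert dim % n_heads == 0 and n_heads % n_kv_heads == 0
        self.n_heads, self.n_kv_heads = n_heads, n_kv_heads
        self.head_dim = dim // n_heads
        self.q = nn.Linear(dim, n_heads * self.head_dim)
        self.k = nn.Linear(dim, n_kv_heads * self.head_dim)
        self.v = nn.Linear(dim, n_kv_heads * self.head_dim)

    def forward(self, x):  # x: (batch, seq_len, dim), fp16/bf16 on CUDA
        b, s, _ = x.shape
        groups = self.n_heads // self.n_kv_heads
        q = self.q(x).view(b, s, self.n_heads, self.head_dim)
        # Expand each KV head to its group of query heads (head axis is dim=2).
        # flash-attn 2.x can also take K/V with fewer heads directly, which avoids this copy.
        k = self.k(x).view(b, s, self.n_kv_heads, self.head_dim).repeat_interleave(groups, dim=2)
        v = self.v(x).view(b, s, self.n_kv_heads, self.head_dim).repeat_interleave(groups, dim=2)
        return flash_attn_func(q, k, v, causal=True)  # (batch, seq_len, n_heads, head_dim)
```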
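The main payoff of GQA at inference time is a smaller KV cache. The following back-of-the-envelope sketch compares cache sizes; the layer count, head counts, and head dimension are illustrative assumptions, and `kv_cache_bytes` is a hypothetical helper.

```python
def kv_cache_bytes(seq_len, n_layers, n_kv_heads, head_dim, batch=1, bytes_per_elem=2):
    """Bytes for keys plus values across all layers (the factor 2 covers K and V)."""
    return 2 * batch * n_layers * seq_len * n_kv_heads * head_dim * bytes_per_elem


# Same hypothetical model, once with full multi-head KV and once with 8 KV heads
mha = kv_cache_bytes(seq_len=8192, n_layers=32, n_kv_heads=32, head_dim=128)
gqa = kv_cache_bytes(seq_len=8192, n_layers=32, n_kv_heads=8, head_dim=128)
print(f"MHA: {mha / 2**30:.1f} GiB, GQA: {gqa / 2**30:.1f} GiB, ratio: {mha // gqa}x")
```

The cache shrinks by exactly `n_heads / n_kv_heads`, so with these numbers the per-sequence cache drops from 4 GiB to 1 GiB in fp16.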
### KV cache (inference)
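```python
# Toy sketch of a paged KV cache block table (the idea behind PagedAttention)
class PagedKVCache:
    def __init__(self, n_layers, max_seqs, block_size=16):
        self.n_layers, self.max_seqs = n_layers, max_seqs
        self.block_size = block_size
        self.blocks = {}        # (seq_id, logical block index) -> physical block id
        self.n_blocks = {}      # seq_id -> number of logical blocks allocated so far
        self.next_physical = 0  # naive allocator: hands out fresh ids, never frees

    def allocate_block(self):
        physical = self.next_physical
        self.next_physical += 1
        return physical

    def append(self, seq_id, k_block, v_block):
        physical = self.allocate_block()
        logical = self.n_blocks.get(seq_id, 0)  # per-sequence index, not a global count
        self.blocks[(seq_id, logical)] = physical
        self.n_blocks[seq_id] = logical + 1
        # A real cache would copy k_block / v_block into the physical block here and
        # hand the block table to a kernel such as flash_attn_with_kvcache
        return physical
```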
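Paging works because a token position splits cleanly into a logical block index and an offset inside that block. The sketch below shows that address arithmetic with hypothetical helpers (`locate_token`, `blocks_needed`) and a made-up block table; it assumes fixed-size blocks as in the class above.

```python
def locate_token(token_pos, block_table, block_size=16):
    """Map a token position to (physical block id, offset within the block)."""
    logical_block, offset = divmod(token_pos, block_size)
    return block_table[logical_block], offset


def blocks_needed(seq_len, block_size=16):
    """Number of fixed-size blocks required to hold seq_len tokens."""
    return -(-seq_len // block_size)  # ceiling division


block_table = [7, 2, 9]  # logical blocks 0, 1, 2 live in physical blocks 7, 2, 9
print(locate_token(17, block_table))  # second token of logical block 1
print(blocks_needed(33))
```

Since physical blocks need not be contiguous, memory waste is bounded by at most one partially filled block per sequence.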
### Backward (recomputation)
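```python
from flash_attn import flash_attn_func

# q, k, v are CUDA fp16/bf16 tensors created with requires_grad=True.
# The forward pass keeps only small per-row softmax statistics; backward recomputes the attention blocks.
# This is built into flash_attn and needs no extra code.
out = flash_attn_func(q, k, v)
out.sum().backward()  # backward() needs a scalar, so reduce first
```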
### Compile + Flash
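```python
import torch

# torch.compile (PyTorch 2.x) fuses the surrounding ops; `model` is any nn.Module
model = torch.compile(model)
# Attention layers that call F.scaled_dot_product_attention still dispatch to Flash when available
```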
### Detect Flash availability
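```python
import torch

def has_flash():
    """True if flash_attn is importable and the GPU is Ampere (compute capability 8.x) or newer."""
    try:
        from flash_attn import flash_attn_func  # noqa: F401
    except ImportError:
        return False
    return torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
```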
### H100 / FA3
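```python
# FA3 (2024): Hopper (H100) kernels with asynchronous execution.
# Assumes the separately built FA3 package, which exposes the flash_attn_interface module.
from flash_attn_interface import flash_attn_func
# Largely the same call signature; reported as roughly 1.5-2x faster than FA2 on H100
```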
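### Custom masks (block-sparse)
```python
# Arbitrary custom masks are not handled efficiently by the standard kernels.
# Sparse patterns (e.g. Longformer global + local) need flash-attn variants.
# The experimental Triton implementation takes an additive bias broadcastable to
# (batch, n_heads, seqlen_q, seqlen_k), with -inf at masked positions.
from flash_attn.flash_attn_triton import flash_attn_func
out = flash_attn_func(q, k, v, custom_block_mask)  # 4th positional argument is the additive bias
```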
### vLLM serving
```bash
# --max-model-len must not exceed the model's context limit (8,192 tokens for Llama 3 8B)
python -m vllm.entrypoints.openai.api_server \
  --model meta-llama/Meta-Llama-3-8B \
  --dtype bfloat16 \
  --max-model-len 8192
```
## Decision criteria
| Situation | Approach |
|---|---|
| Default training | PyTorch sdpa (auto) |
| Long context (packed, variable-length batches) | flash_attn_varlen_func |
| Production serving | vLLM (paged) |
| Custom mask | xformers / flash-attn variants |
| H100 | FA3 |
| Mobile / non-CUDA | Math fallback |

**Default**: PyTorch sdpa + vLLM serving + GQA + paged KV cache + FA3 on H100.
## 🔗 Graph
- Parent: [[Transformer]] · [[Attention Mechanism]]
- Variants: [[PagedAttention]] · [[Sliding-Window]] · [[GQA]]
- Applications: [[LLM_Optimization_and_Deployment_Strategies|vLLM]] · [[Long-Context]]
- Adjacent: [[LLM_Optimization_and_Deployment_Strategies|Quantization]] · [[Foundation-Models]]
## 🤖 LLM usage
**When to use**: all transformer training and inference.
**When not to use**: non-CUDA targets (e.g. mobile).
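## ❌ Anti-patterns
- **Manual attention loop**: slow.
- **Padding to the batch maximum**: use the varlen API instead.
- **No KV cache**: every decoding step recomputes attention over the whole prefix, so inference cost grows quadratically.
- **Non-Flash attention in production**: higher cost.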
## 🧪 Verification / duplicates
- Verified (Dao 2022/2023/2024 FA papers, vLLM Kwon 2023).
- Trust level A.
## 🕓 Changelog
| Date | Change |
|---|---|
| 2026-04-20 | Auto-reinforced |
| 2026-05-08 | Phase 1 |
| 2026-05-10 | Manual cleanup: algorithm + PyTorch / flash-attn / vLLM / GQA / FA3 code |