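koriweb d8a80f6272 chore(wiki): normalize dangling links to canonical titles (768 files / 1,200 links)
Replaced [[wikilinks]] that differed only by spelling variant with the canonical title of the target document,
reconnecting 1,200 broken links. Only exact title/filename normalization matches were applied; alias matching was
excluded because of the risk of over-merging (ambiguity guard). Originals are backed up in _link_reconcile_backup/.
Tool: Datacollect/scripts/link_reconcile_apply.mjs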

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-08 12:24:15 +09:00


id: wiki-2026-0508-flash-attention
title: Flash Attention
category: 10_Wiki/Topics
status: verified
canonical_id: self
aliases: FlashAttention, FA2, FA3, IO-aware attention, Tri Dao, online softmax
duplicate_of: none
source_trust_level: A
confidence_score: 0.98
verification_status: applied
tags: transformer, attention, gpu, optimization, flash-attention, memory-efficient
raw_sources:
last_reinforced: 2026-05-10
github_commit: pending
tech_stack: language: CUDA / PyTorch; framework: flash-attn / xformers / vLLM

Flash Attention

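One-liner

"An IO-aware, tile-based, exact attention algorithm." By Tri Dao and collaborators: FA1 (2022), FA2 (2023), FA3 (2024). It fixes attention's quadratic memory. Extra memory drops from O(N²) to O(N) in sequence length, while compute stays O(N²) and the result stays exact. It is the standard in modern transformers and is used by vLLM, xformers and native PyTorch.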

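Core ideas

Problem (vanilla attention)

  • Standard attention materializes the full N×N score matrix: O(N²) memory.
  • HBM bandwidth, not FLOPs, is the bottleneck: the matrix is written to and read back from GPU main memory.
  • Long context: out-of-memory (OOM) errors.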
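The sketch below makes the quadratic term concrete. It computes the size of one materialized N×N score matrix and compares it with keeping one fp32 statistic per query row. Assumptions: fp16 scores (2 bytes), one head, one sequence. The per-row statistic is modeled on the logsumexp that FlashAttention-style kernels keep for the backward pass. The helper names are hypothetical.

```python
def score_matrix_bytes(seq_len: int, bytes_per_elem: int = 2) -> int:
    """Bytes for one materialized seq_len x seq_len score matrix (one head, one sequence)."""
    return seq_len * seq_len * bytes_per_elem


def row_stats_bytes(seq_len: int, bytes_per_elem: int = 4) -> int:
    """Bytes for one fp32 statistic per query row (e.g. a logsumexp): linear in seq_len."""
    return seq_len * bytes_per_elem


n = 32_768
print(f"N x N scores: {score_matrix_bytes(n) / 2**30:.1f} GiB per head")
print(f"per-row stats: {row_stats_bytes(n) / 2**10:.1f} KiB per head")
```

Multiplying the first figure by batch size and head count shows how quickly the materialized matrix dominates GPU memory at long context.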

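Solution (FlashAttention)

  • Tiling: split Q, K, V into blocks.
  • Online softmax: normalize incrementally, block by block, without ever building the full matrix.
  • SRAM compute: each tile is processed in fast on-chip SRAM, so the N×N matrix never goes to HBM.
  • Recomputation: the backward pass recomputes attention from saved per-row statistics instead of storing it, trading compute for memory.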
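The online-softmax step can be sketched for a single query row in plain Python. Keys are consumed block by block, and only three things are kept: a running max `m`, a running denominator `denom`, and an unnormalized accumulator. This is a minimal illustration of the rescaling trick with hypothetical function names, not the fused kernel. It should match a one-shot softmax-weighted sum up to rounding.

```python
import math


def softmax_weighted_sum(scores, values):
    """Reference: softmax(scores) @ values over the whole row at once."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    total = sum(w)
    return [sum(wi * v[c] for wi, v in zip(w, values)) / total for c in range(len(values[0]))]


def online_softmax_weighted_sum(scores, values, block_size=2):
    """Same result, consuming scores/values block by block while keeping only
    a running max m, a running denominator denom and an unnormalized accumulator acc."""
    m, denom, acc = float("-inf"), 0.0, [0.0] * len(values[0])
    for start in range(0, len(scores), block_size):
        s_blk = scores[start:start + block_size]
        v_blk = values[start:start + block_size]
        m_new = max(m, max(s_blk))
        alpha = math.exp(m - m_new)  # rescales old statistics; exp(-inf) = 0 on the first block
        p = [math.exp(s - m_new) for s in s_blk]
        denom = denom * alpha + sum(p)
        acc = [a * alpha + sum(pi * v[c] for pi, v in zip(p, v_blk)) for c, a in enumerate(acc)]
        m = m_new
    return [a / denom for a in acc]


scores = [0.5, 2.0, -1.0, 3.0, 0.0]
values = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, -1.0], [0.5, 0.5]]
print(online_softmax_weighted_sum(scores, values, block_size=2))
print(softmax_weighted_sum(scores, values))
```

The factor `alpha = exp(m_old - m_new)` is what lets earlier blocks be corrected after a larger score appears. That is why no block ever has to be revisited.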

Versions

  • FA1 (2022): the original IO-aware algorithm.
  • FA2 (2023): better parallelism and work partitioning, about 2× faster than FA1.
  • FA3 (2024): optimized for H100 (Hopper), using asynchronous execution.

Applications

  1. All transformer training.
  2. Long-context (100K+).
  3. Inference (vLLM, TGI).
  4. Multi-query / GQA.
  5. Sparse / sliding window.

💻 Patterns

PyTorch native (FA built-in)

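import torch
import torch.nn.functional as F

# PyTorch 2.0+: scaled_dot_product_attention dispatches to the FlashAttention kernel when eligible
# (CUDA GPU, fp16/bf16 inputs, supported head_dim); otherwise it falls back to another backend
q, k, v = (torch.randn(2, 8, 1024, 64, device='cuda', dtype=torch.float16) for _ in range(3))  # (batch, heads, seq, head_dim)
out = F.scaled_dot_product_attention(q, k, v, is_causal=True)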
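A naive pure-Python reference pins down what `is_causal=True` computes: softmax(QKᵀ/√d + causal mask)V, using the default 1/√d scale. FlashAttention is exact, so a Flash-backed call should agree with this up to floating-point rounding. The helper is illustrative and is not a PyTorch API.

```python
import math


def naive_causal_attention(q, k, v):
    """Reference softmax(q @ k.T / sqrt(d) + causal_mask) @ v on lists of row vectors.
    It builds every score row explicitly, which is exactly what Flash avoids storing."""
    d = len(q[0])
    out = []
    for i, qi in enumerate(q):
        # causal: query i only sees keys 0..i
        scores = [sum(a * b for a, b in zip(qi, k[j])) / math.sqrt(d) for j in range(i + 1)]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        total = sum(w)
        out.append([sum(w[j] * v[j][c] for j in range(i + 1)) / total for c in range(len(v[0]))])
    return out


q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(naive_causal_attention(q, k, v))
```

The first output row is always `v[0]`, because the causal mask leaves query 0 a single key to attend to.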

flash-attn (Tri Dao package)

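import torch
from flash_attn import flash_attn_func, flash_attn_varlen_func

# Standard: q, k, v are (batch, seqlen, nheads, headdim), fp16/bf16, on CUDA
q, k, v = (torch.randn(2, 1024, 8, 64, device='cuda', dtype=torch.bfloat16) for _ in range(3))
out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)

# Variable length (no padding waste): sequences packed along dim 0 as (total_tokens, nheads, headdim)
qp, kp, vp = (torch.randn(1024, 8, 64, device='cuda', dtype=torch.bfloat16) for _ in range(3))
cu_seqlens = torch.tensor([0, 300, 1024], device='cuda', dtype=torch.int32)  # two sequences: 300 + 724 tokens
max_seqlen = 724
out = flash_attn_varlen_func(qp, kp, vp, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, causal=True)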
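The `cu_seqlens_*` arguments are cumulative offsets into the packed token dimension: sequence `i` occupies rows `cu[i]:cu[i+1]`. The `max_seqlen_*` arguments give the longest sequence. This small sketch, with hypothetical helper names, derives both from per-sequence lengths. It also counts how many token slots padding to the batch maximum would have wasted.

```python
from itertools import accumulate


def build_cu_seqlens(seqlens):
    """Cumulative start offsets of each packed sequence: [0, l0, l0 + l1, ...].
    flash_attn_varlen_func expects this as an int32 CUDA tensor; plain ints here."""
    return [0, *accumulate(seqlens)]


def padding_overhead(seqlens):
    """Token slots spent on padding if every sequence is padded to the batch max."""
    return len(seqlens) * max(seqlens) - sum(seqlens)


seqlens = [300, 724, 120]
cu = build_cu_seqlens(seqlens)
print(cu, max(seqlens), padding_overhead(seqlens))
```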

xformers

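import torch
from xformers.ops import memory_efficient_attention, LowerTriangularMask
q, k, v = (torch.randn(2, 1024, 8, 64, device='cuda', dtype=torch.float16) for _ in range(3))  # (batch, seq, heads, head_dim)
out = memory_efficient_attention(q, k, v, attn_bias=LowerTriangularMask())  # causal masking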

vLLM (paged attention serving)

from vllm import LLM, SamplingParams
llm = LLM(model='meta-llama/Meta-Llama-3-8B', dtype='bfloat16')
# vLLM manages the KV cache with PagedAttention and uses FlashAttention-based kernels where supported
outputs = llm.generate(['Hello'], SamplingParams(max_tokens=100))
print(outputs[0].outputs[0].text)

Manual Flash-style (educational, simplified)

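import math
import torch

def flash_attention_simple(Q, K, V, block_size=64):
    """Simplified, non-causal FlashAttention forward; Q, K, V: (batch, N, d).
    Real kernels run these tiled loops fused in CUDA, keeping tiles in on-chip SRAM."""
    N, d = Q.shape[1], Q.shape[2]
    scale = 1.0 / math.sqrt(d)
    O = torch.zeros_like(Q)  # unnormalized output accumulator
    L = torch.zeros(Q.shape[:2], dtype=Q.dtype, device=Q.device)  # running softmax denominator
    M = torch.full(Q.shape[:2], float('-inf'), dtype=Q.dtype, device=Q.device)  # running row max

    for j in range(0, N, block_size):  # outer loop: K/V blocks
        Kj, Vj = K[:, j:j+block_size], V[:, j:j+block_size]
        for i in range(0, N, block_size):  # inner loop: Q blocks
            Qi = Q[:, i:i+block_size]
            Sij = (Qi @ Kj.transpose(-1, -2)) * scale
            M_old = M[:, i:i+block_size, None]
            M_new = torch.maximum(M_old, Sij.max(dim=-1, keepdim=True).values)
            Pij = torch.exp(Sij - M_new)
            alpha = torch.exp(M_old - M_new)  # online rescaling of earlier blocks (0 on the first block)
            O[:, i:i+block_size] = O[:, i:i+block_size] * alpha + Pij @ Vj
            L[:, i:i+block_size] = L[:, i:i+block_size] * alpha.squeeze(-1) + Pij.sum(dim=-1)
            M[:, i:i+block_size] = M_new.squeeze(-1)
    return O / L[..., None]  # final normalization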

Sliding window (Mistral-style)

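from flash_attn import flash_attn_func
window_left = 4096  # Mistral 7B's sliding window: each query sees itself plus the previous 4096 keys
out = flash_attn_func(q, k, v, window_size=(window_left, 0), causal=True)  # sliding window needs flash-attn >= 2.3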
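The flash-attn documentation describes `window_size=(left, right)` for equal query and key lengths. Query `i` may attend keys `i - left` through `i + right` inclusive. So `(window_left, 0)` with `causal=True` limits each token to itself plus the previous `window_left` tokens. A hypothetical helper that lists those positions:

```python
def window_keys(i, left, right=0, seq_len=None):
    """Key positions query i may attend to under window_size=(left, right),
    assuming equal query/key lengths; right=0 matches the causal case."""
    hi = i + right if seq_len is None else min(i + right, seq_len - 1)
    return list(range(max(0, i - left), hi + 1))


print(window_keys(5, left=2))  # [3, 4, 5]
print(len(window_keys(10_000, left=4096)))
```

Each query sees at most `left + 1` keys. Work therefore grows roughly as N × window rather than N² once the kernel skips tiles outside the band.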

Grouped Query Attention (GQA)

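import torch.nn as nn
from flash_attn import flash_attn_func

class GQA(nn.Module):
    def __init__(self, dim, n_heads, n_kv_heads):
        super().__init__()
        assert n_heads % n_kv_heads == 0, "query heads must split evenly into KV groups"
        self.n_heads, self.n_kv_heads, self.head_dim = n_heads, n_kv_heads, dim // n_heads
        self.q = nn.Linear(dim, n_heads * self.head_dim)
        self.k = nn.Linear(dim, n_kv_heads * self.head_dim)
        self.v = nn.Linear(dim, n_kv_heads * self.head_dim)

    def forward(self, x):  # x: (batch, seq, dim), fp16/bf16 on CUDA
        B, T, _ = x.shape
        q = self.q(x).view(B, T, self.n_heads, self.head_dim)
        k = self.k(x).view(B, T, self.n_kv_heads, self.head_dim)
        v = self.v(x).view(B, T, self.n_kv_heads, self.head_dim)
        # flash_attn_func supports GQA natively (fewer KV heads than Q heads),
        # so no repeat_interleave copy of K/V is needed
        out = flash_attn_func(q, k, v, causal=True)
        return out.reshape(B, T, -1)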
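Two bookkeeping facts sit behind GQA. First, query head `h` reads KV head `h // (n_heads // n_kv_heads)`. That is the grouping `repeat_interleave` produces, and the one flash-attn applies natively. Second, the KV cache shrinks by a factor of `n_heads / n_kv_heads`. A small sketch of both follows. The helper names are hypothetical, and the model dimensions (32 layers, 32 query heads, 8 KV heads, head_dim 128, fp16) are assumed for illustration.

```python
def kv_head_for(q_head, n_heads, n_kv_heads):
    """KV head shared by a query head; matches repeat_interleave(n_heads // n_kv_heads, dim=heads)."""
    assert n_heads % n_kv_heads == 0
    return q_head // (n_heads // n_kv_heads)


def kv_cache_bytes_per_token(n_layers, n_kv_heads, head_dim, bytes_per_elem=2):
    """K and V entries cached per token across all layers (fp16/bf16 by default)."""
    return 2 * n_layers * n_kv_heads * head_dim * bytes_per_elem


mha = kv_cache_bytes_per_token(32, 32, 128)  # one KV head per query head
gqa = kv_cache_bytes_per_token(32, 8, 128)   # 4 query heads share each KV head
print([kv_head_for(h, 8, 2) for h in range(8)])
print(mha // 1024, gqa // 1024)  # KiB per token
```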

KV cache (inference)

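# Paged attention (vLLM-style): the KV cache lives in fixed-size physical blocks, and each
# sequence keeps a block table mapping its logical block index -> physical block id
class PagedKVCache:
    def __init__(self, num_blocks, block_size=16):
        self.block_size = block_size
        self.free = list(range(num_blocks))  # free physical block ids
        self.block_tables = {}               # seq_id -> [physical block ids]
        self.lengths = {}                    # seq_id -> number of cached tokens

    def append_token(self, seq_id):
        """Reserve the KV slot for one new token; returns (physical_block, offset)."""
        table = self.block_tables.setdefault(seq_id, [])
        n = self.lengths.get(seq_id, 0)
        if n % self.block_size == 0:         # no block yet, or the last one is full
            table.append(self.free.pop())
        self.lengths[seq_id] = n + 1
        return table[-1], n % self.block_size
    # Block tables padded into an int32 (batch, max_blocks) tensor can be passed to
    # flash_attn_with_kvcache(..., block_table=...) in flash-attn builds with paged-KV support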
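Paged attention reads K/V through a per-sequence block table rather than one contiguous buffer. The address arithmetic is sketched below, assuming block size 16 (the default above); the helper names and block ids are hypothetical.

```python
def locate(block_table, pos, block_size=16):
    """Map a token's logical position to (physical block id, offset within block)."""
    return block_table[pos // block_size], pos % block_size


def blocks_needed(num_tokens, block_size=16):
    """Physical blocks a sequence occupies; at most block_size - 1 slots stay unused."""
    return -(-num_tokens // block_size)  # ceiling division


table = [7, 2, 9]  # hypothetical physical block ids for one sequence
print(locate(table, 20))  # (2, 4)
print(blocks_needed(33))  # 3
```

Waste is bounded by one partially filled block per sequence. That bound, not the lookup itself, is the main point of paging.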

Backward (recomputation)

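# Forward keeps only Q, K, V, O and a per-row softmax statistic (logsumexp); backward
# recomputes attention tiles from them instead of storing the N×N matrix (automatic in flash_attn)
q, k, v = (torch.randn(2, 1024, 8, 64, device='cuda', dtype=torch.bfloat16, requires_grad=True) for _ in range(3))
out = flash_attn_func(q, k, v, causal=True)
out.sum().backward()  # .backward() needs a scalar loss (or an explicit gradient tensor)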
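A few statistics suffice because of this: if the forward pass saves each row's logsumexp, the backward pass can rebuild every probability as `exp(s - lse)` once it has recomputed the scores from Q and K. A plain-Python illustration for one row follows; the function names are hypothetical.

```python
import math


def logsumexp(scores):
    """Numerically stable log(sum(exp(s))) for one score row: one float saved per row."""
    m = max(scores)
    return m + math.log(sum(math.exp(s - m) for s in scores))


def recompute_probs(scores, lse):
    """Backward-style recomputation: softmax probabilities from scores and the saved lse."""
    return [math.exp(s - lse) for s in scores]


row = [1.0, -0.5, 2.5, 0.0]
lse = logsumexp(row)               # saved in the forward pass
probs = recompute_probs(row, lse)  # rebuilt in the backward pass
print(probs)
```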

Compile + Flash

# torch.compile fuses the surrounding ops; attention written with
# F.scaled_dot_product_attention still dispatches to the Flash kernel when eligible
model = torch.compile(model)

Detect Flash availability

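import torch

def has_flash():
    """True if flash-attn is importable and the GPU is Ampere (SM 8.x) or newer, as FA2 requires."""
    try:
        import flash_attn  # noqa: F401
    except ImportError:
        return False
    return torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8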

H100 / FA3

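# FA3 (2024): Hopper (H100) kernels using asynchronous execution; installed separately
# from the hopper/ directory of the flash-attention repo
from flash_attn_interface import flash_attn_func
# Call style mirrors FA2; reported 1.5-2x faster than FA2 on H100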

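Custom masks (block-sparse)

# The standard flash_attn_func kernels offer only built-in patterns (causal, sliding window), not arbitrary masks
# Structured sparsity (e.g. Longformer-style global + local) → flash-attn variants
from flash_attn.flash_attn_triton import flash_attn_func  # experimental Triton kernel
# Takes an additive bias rather than a boolean mask (0 = keep, -inf = masked); attn_bias is a placeholder tensor
out = flash_attn_func(q, k, v, attn_bias)  # attn_bias broadcastable to (batch, nheads, seqlen_q, seqlen_k)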
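Block-sparse kernels work at tile granularity. They determine which (query block, key block) tiles contain any unmasked entry and skip the rest. The reduction is sketched below with a hypothetical helper and a causal predicate as the example mask.

```python
def block_mask(n, block, allowed):
    """Collapse an element-level mask predicate allowed(i, j) into a tile-level mask:
    tile (q_block, k_block) is kept if any element inside it is allowed."""
    nb = -(-n // block)  # number of blocks per side (ceiling division)
    return [[any(allowed(i, j)
                 for i in range(qb * block, min((qb + 1) * block, n))
                 for j in range(kb * block, min((kb + 1) * block, n)))
             for kb in range(nb)]
            for qb in range(nb)]


def causal(i, j):
    return j <= i


mask = block_mask(8, 2, causal)
active = sum(map(sum, mask))
print(f"{active} of {len(mask) ** 2} tiles computed")
```

For a causal mask, about half the tiles are skipped. Sparser structured patterns skip proportionally more.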

vLLM serving

# --max-model-len cannot exceed the model's context window (8192 tokens for Llama 3 8B)
python -m vllm.entrypoints.openai.api_server \
  --model meta-llama/Meta-Llama-3-8B \
  --dtype bfloat16 \
  --max-model-len 8192

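Decision criteria

| Situation | Approach |
|---|---|
| Default training | PyTorch sdpa (automatic backend selection) |
| Long context / packed variable-length batches | flash_attn_varlen_func |
| Production serving | vLLM (paged attention) |
| Custom masks | xformers / flash-attn variants |
| H100 | FA3 |
| Mobile / non-CUDA | sdpa math fallback |

Defaults: PyTorch sdpa for training, vLLM for serving, GQA, a paged KV cache, and FA3 on H100.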
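The hardware and serving rows of the decision table can be read as a small function. This is a hypothetical encoding for illustration only. It leaves out the long-context and custom-mask rows, and it assumes the caller already knows whether FA3 is installed.

```python
def choose_attention_path(cuda_capability=None, serving=False, has_fa3=False):
    """Hypothetical encoding of the decision table.
    cuda_capability: (major, minor), e.g. from torch.cuda.get_device_capability(); None without CUDA."""
    if cuda_capability is None:
        return "sdpa math fallback"
    if serving:
        return "vLLM (paged attention)"
    if cuda_capability[0] >= 9 and has_fa3:  # Hopper (H100) is compute capability 9.0
        return "FA3"
    return "PyTorch sdpa"


print(choose_attention_path((9, 0), has_fa3=True))
print(choose_attention_path((8, 0)))
```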

🔗 Graph

🤖 LLM usage

When to use: every transformer training or inference workload. When not to: non-CUDA targets (e.g. mobile).

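Anti-patterns

  • Hand-written attention loops in Python: slow.
  • Padding every sequence to the batch max: use the varlen API instead.
  • No KV cache at inference: each new token recomputes attention over the whole prefix, so per-token cost grows quadratically with context length.
  • Non-Flash attention in production: higher cost.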
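The KV-cache anti-pattern can be quantified by counting query-key dot products during decoding. This sketch rests on simplifying assumptions: prefill is excluded, and "no cache" means re-running causal attention over the whole prefix at every step. The helper name is hypothetical.

```python
def decode_score_count(prompt_len, new_tokens, use_cache):
    """Query-key dot products while generating new_tokens after a prompt (prefill excluded).
    With a KV cache, step t scores only the new query against L = prompt_len + t keys.
    Without one, the causal prefix is recomputed: L * (L + 1) / 2 scores per step."""
    total = 0
    for t in range(1, new_tokens + 1):
        L = prompt_len + t
        total += L if use_cache else L * (L + 1) // 2
    return total


print(decode_score_count(1000, 100, use_cache=True))
print(decode_score_count(1000, 100, use_cache=False))
```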

🧪 Verification / duplicates

  • Verified (Dao 2022/2023/2024 FlashAttention papers; vLLM, Kwon et al. 2023).
  • Trust level: A.

🕓 Changelog

| Date | Change |
|---|---|
| 2026-04-20 | Auto-reinforced |
| 2026-05-08 | Phase 1 |
| 2026-05-10 | Manual cleanup: algorithm + PyTorch / flash-attn / vLLM / GQA / FA3 code |