| id | title | category | status | canonical_id | aliases | duplicate_of | source_trust_level | confidence_score | verification_status | tags | raw_sources | last_reinforced | github_commit | tech_stack |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| wiki-2026-0508-ring-attention | Ring Attention | 10_Wiki/Topics | verified | self |  | none | A | 0.95 | applied |  |  | 2026-05-10 | pending |  |
Ring Attention
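One-liner
Split the attention sequence across N devices arranged in a ring: context length scales linearly with the number of devices. Proposed by Liu, Zaharia, and Abbeel (2023, "Ring Attention with Blockwise Transformers for Near-Infinite Context"), it is a training-time enabler for 1M+ token context windows (Gemini 1.5 Pro, Claude Opus 4.7 1M). Each K/V transfer overlaps with the attention compute on the current block, so communication adds near-zero overhead as long as computing a block takes at least as long as sending it.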
Core
Key idea
- Split the sequence across N devices; each device holds 1/N of the tokens of Q, K, and V.
- Each device computes attention between its local Q and the K, V block it currently holds.
- K, V blocks circulate around the ring for N steps (N - 1 transfers), so every device sees every block once; each transfer overlaps with the attention compute.
- Result: exact full-sequence attention over a context N times longer than a single device can hold.
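The loop described above can be checked end to end in a single process. A minimal NumPy sketch, assuming non-causal attention, a single head, and NumPy arrays standing in for per-GPU shards; `block_attention`, `merge`, and `simulated_ring_attention` are hypothetical names, and `merge` is the log-sum-exp combination the patterns section calls online softmax merging:

```python
import numpy as np


def block_attention(q, k, v):
    """Exact softmax attention of q against one K/V block -> (out, lse)."""
    s = q @ k.T / np.sqrt(q.shape[-1])                      # (n_q, n_k) scores
    m = s.max(axis=-1)                                      # row max for stability
    lse = m + np.log(np.exp(s - m[:, None]).sum(axis=-1))  # log-sum-exp per row
    return np.exp(s - lse[:, None]) @ v, lse                # softmax(s) @ v


def merge(out_a, lse_a, out_b, lse_b):
    """Combine partial results computed over disjoint key sets."""
    lse = np.logaddexp(lse_a, lse_b)
    out = np.exp(lse_a - lse)[:, None] * out_a + np.exp(lse_b - lse)[:, None] * out_b
    return out, lse


def simulated_ring_attention(q, k, v, n_dev):
    """Rank r owns the r-th contiguous chunk of q, k, v (non-causal)."""
    qs, ks, vs = (np.array_split(x, n_dev) for x in (q, k, v))
    outs = [np.zeros_like(qr) for qr in qs]
    lses = [np.full(qr.shape[0], -np.inf) for qr in qs]  # empty accumulator
    for _ in range(n_dev):
        for r in range(n_dev):  # every "device" computes on the block it holds now
            o_blk, lse_blk = block_attention(qs[r], ks[r], vs[r])
            outs[r], lses[r] = merge(outs[r], lses[r], o_blk, lse_blk)
        # rotate: rank r now holds the block rank r + 1 held (which sent it to r)
        ks, vs = ks[1:] + ks[:1], vs[1:] + vs[:1]
    return np.concatenate(outs)


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((16, 8)) for _ in range(3))
ring = simulated_ring_attention(q, k, v, n_dev=4)
full, _ = block_attention(q, k, v)
print(np.abs(ring - full).max())  # only float64 rounding error remains
```

Because the merge is exact, the order in which a device receives blocks does not matter, and uneven chunk sizes (e.g. 16 tokens over 3 devices) work the same way.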
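vs. alternatives
- Flash Attention: single device, IO-aware, memory-efficient. Ring Attention composes on top of it, using it as the per-block kernel.
- Sequence Parallel (Megatron): also splits along the sequence dimension, but only in the LayerNorm/dropout regions; attention itself is not split by sequence.
- Context Parallel (Megatron 2024): a production-grade Ring Attention variant.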
- Striped Attention (2023): improved load balance for causal masks.
Applications
- 1M+ context LLM training (Gemini 1.5/2.0, Claude Opus 4.x).
- Long video understanding.
- Whole-codebase code models.
- Long DNA sequence models (Evo).
💻 Patterns
Conceptual Ring Loop (single block)
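```python
import torch
import torch.distributed as dist


def blockwise_attention(q, k, v):
    """Exact attention of q against one K/V block -> (out, lse), both fp32.

    Stand-in for a FlashAttention kernel that also returns the log-sum-exp.
    """
    scores = (q.float() @ k.float().transpose(-2, -1)) * q.shape[-1] ** -0.5
    return torch.softmax(scores, dim=-1) @ v.float(), torch.logsumexp(scores, dim=-1)


def ring_attention_step(q_local, kv_local, world_size):
    """Simplified single-pass illustration (non-causal, one Q block per rank).

    Uses online_softmax_merge and ring_send_recv from the patterns that follow.
    """
    out = torch.zeros_like(q_local, dtype=torch.float32)
    lse = torch.full(q_local.shape[:-1], -float("inf"),
                     dtype=torch.float32, device=q_local.device)  # keep LSE in fp32
    k, v = kv_local
    rank = dist.get_rank()
    send_rank = (rank - 1) % world_size  # pass K, V to the left neighbor
    recv_rank = (rank + 1) % world_size  # receive from the right neighbor
    for step in range(world_size):
        # local Q against the K/V block held at this step
        partial_out, partial_lse = blockwise_attention(q_local, k, v)
        out, lse = online_softmax_merge(out, lse, partial_out, partial_lse)
        # rotate K, V (blocking in this sketch); the last step needs no transfer
        if step < world_size - 1:
            k, v = ring_send_recv(k, v, send_rank, recv_rank)
    return out.to(q_local.dtype)
```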
Online Softmax Merge
```python
import torch


def online_softmax_merge(out_a, lse_a, out_b, lse_b):
    """Numerically stable merge of two partial attention results.

    (out_a, lse_a) and (out_b, lse_b) cover disjoint key sets; the result covers
    their union. Subtracting the max m keeps every exponent <= 0 (no overflow).
    """
    m = torch.maximum(lse_a, lse_b)
    c_a = torch.exp(lse_a - m).unsqueeze(-1)
    c_b = torch.exp(lse_b - m).unsqueeze(-1)
    out = (c_a * out_a + c_b * out_b) / (c_a + c_b)
    new_lse = m + torch.log(torch.exp(lse_a - m) + torch.exp(lse_b - m))
    return out, new_lse
```
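Why the merge is exact: a short derivation for one query row, with scores $s_j$ and value vectors $v_j$ (notation introduced here). A partial result over key set $A$ is

$$
\ell_a = \log \sum_{j \in A} e^{s_j}, \qquad
o_a = \sum_{j \in A} e^{s_j - \ell_a}\, v_j ,
$$

and likewise $(\ell_b, o_b)$ over a disjoint key set $B$. Attention over $A \cup B$ is then

$$
o = \frac{\sum_{j \in A \cup B} e^{s_j} v_j}{\sum_{j \in A \cup B} e^{s_j}}
  = \frac{e^{\ell_a} o_a + e^{\ell_b} o_b}{e^{\ell_a} + e^{\ell_b}}
  = \frac{c_a o_a + c_b o_b}{c_a + c_b},
\qquad
\ell = \log\left(e^{\ell_a} + e^{\ell_b}\right) = m + \log\left(c_a + c_b\right),
$$

with $m = \max(\ell_a, \ell_b)$, $c_a = e^{\ell_a - m}$, and $c_b = e^{\ell_b - m}$. These match `m`, `c_a`, `c_b`, `out`, and `new_lse` in `online_softmax_merge` term by term. The merge is associative and commutative, so the order in which K/V blocks arrive around the ring changes the result only through floating-point rounding.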
Ring Send/Recv (NCCL)
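```python
import torch
import torch.distributed as dist


def ring_send_recv(k, v, send_rank, recv_rank):
    """Send K, V to send_rank and receive the next K, V from recv_rank."""
    k_buf = torch.empty_like(k)
    v_buf = torch.empty_like(v)
    # Batch all four P2P ops so NCCL launches them as one group; posting
    # isend/irecv one by one may deadlock when every rank sends first.
    ops = [
        dist.P2POp(dist.isend, k.contiguous(), send_rank),
        dist.P2POp(dist.isend, v.contiguous(), send_rank),
        dist.P2POp(dist.irecv, k_buf, recv_rank),
        dist.P2POp(dist.irecv, v_buf, recv_rank),
    ]
    for req in dist.batch_isend_irecv(ops):
        req.wait()
    return k_buf, v_buf
```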
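With the rotation used in the ring loop (each rank sends to rank - 1 and receives from rank + 1), the K/V block a rank holds at step s is (rank + s) mod world_size. A small pure-Python sketch of that schedule, assuming one contiguous block per rank and that step 0 starts with the rank's own block; `kv_block_held` and `ring_schedule` are hypothetical helper names:

```python
def kv_block_held(rank: int, step: int, world_size: int) -> int:
    """Index of the K/V block `rank` holds at ring step `step`.

    Assumes every rank sends to (rank - 1) and receives from (rank + 1),
    and that step 0 uses the rank's own block.
    """
    return (rank + step) % world_size


def ring_schedule(world_size: int) -> list[list[int]]:
    """schedule[step][rank] -> K/V block index processed at that step."""
    return [[kv_block_held(r, s, world_size) for r in range(world_size)]
            for s in range(world_size)]


if __name__ == "__main__":
    for step, row in enumerate(ring_schedule(4)):
        print(f"step {step}: {row}")
```

Each row is a permutation (no two ranks work on the same block at once), and each rank sees every block exactly once. With one contiguous block per rank, the returned index is the `kv_block_idx` a causal skip check needs at each step.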
Striped (Causal-aware) Block Order
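```python
def striped_block_order(seq_len, world_size, block_size):
    """Striped block order for causal load balance: interleave with stride world_size.

    Rank r gets blocks r, r + world_size, r + 2 * world_size, ...; the flat list
    is rank 0's blocks, then rank 1's, and so on, so chunk r of length
    n_blocks // world_size is rank r's share.
    """
    n_blocks = seq_len // block_size
    assert n_blocks % world_size == 0, "n_blocks must divide evenly across ranks"
    return [i * world_size + r
            for r in range(world_size)
            for i in range(n_blocks // world_size)]
```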
Causal Mask Skip Optimization
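```python
def should_compute(q_block_idx, kv_block_idx, causal=True):
    """Causal: skip K/V blocks that lie entirely in the future (kv > q).

    Indices are global block positions in the original sequence. The diagonal
    block (kv == q) is still computed, with an intra-block causal mask.
    """
    return (not causal) or kv_block_idx <= q_block_idx
```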
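A pure-Python sketch of why the striped order matters for causal masks in a ring. Because all ranks move to the next step together, a step lasts as long as its busiest rank, so the cost model below sums the per-step maximum. The two layouts, the cost model (work counted in K/V block pairs that survive the causal skip, diagonal pairs counted as full work), and all function names are illustrative assumptions:

```python
def contiguous_layout(n_blocks: int, world_size: int) -> list[list[int]]:
    """Rank r owns blocks r*k .. r*k + k - 1 (k = blocks per rank)."""
    k = n_blocks // world_size
    return [[r * k + i for i in range(k)] for r in range(world_size)]


def striped_layout(n_blocks: int, world_size: int) -> list[list[int]]:
    """Rank r owns blocks r, r + W, r + 2W, ... (W = world_size)."""
    k = n_blocks // world_size
    return [[i * world_size + r for i in range(k)] for r in range(world_size)]


def causal_work(q_blocks: list[int], kv_blocks: list[int]) -> int:
    """Block pairs kept by the causal skip (kv <= q); diagonal counted as full."""
    return sum(1 for q in q_blocks for kv in kv_blocks if kv <= q)


def critical_path(layout: list[list[int]]) -> int:
    """Sum over ring steps of the busiest rank's work.

    At step s, rank r holds the K/V blocks of rank (r + s) % W.
    """
    w = len(layout)
    return sum(max(causal_work(layout[r], layout[(r + s) % w]) for r in range(w))
               for s in range(w))


if __name__ == "__main__":
    for n_blocks in (8, 64):
        print(n_blocks,
              critical_path(contiguous_layout(n_blocks, 4)),
              critical_path(striped_layout(n_blocks, 4)))
```

With 8 blocks on 4 ranks the contiguous layout's critical path is 15 block pairs against 12 for striped (16 with no skipping at all); with 64 blocks it is 904 against 544 (1024 without skipping). Under the contiguous layout some rank always holds a fully visible block after step 0, so skipping barely shortens the critical path, while striping gives every rank a similar share at every step.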
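Compute/Comm Overlap (async P2P)
```python
import torch
import torch.distributed as dist

def overlapped_step(q, k, v, send_rank, recv_rank):
    # Post the next K/V transfer first; NCCL runs it on its own stream,
    # so it overlaps with the attention compute below.
    k_next, v_next = torch.empty_like(k), torch.empty_like(v)
    reqs = dist.batch_isend_irecv([
        dist.P2POp(dist.isend, k, send_rank), dist.P2POp(dist.isend, v, send_rank),
        dist.P2POp(dist.irecv, k_next, recv_rank), dist.P2POp(dist.irecv, v_next, recv_rank)])
    partial = blockwise_attention(q, k, v)  # assumed (out, lse) helper; runs during the transfer
    for req in reqs:
        req.wait()  # block only before the next step needs k_next / v_next
    return partial, k_next, v_next
```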
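The overlap pattern above only hides the transfer when a block takes at least as long to compute as to send. A back-of-envelope estimate, assuming one attention head, bf16 K/V (2 bytes per element), forward pass only, and softmax FLOPs ignored; $c$ is tokens per block, $d$ the head dimension, $F$ the device throughput in FLOP/s, and $B$ the ring link bandwidth in bytes/s (notation introduced here). One ring step runs two $c \times d \times c$ matmuls and ships one K and one V block:

$$
t_{\text{compute}} \approx \frac{2 \cdot 2c^{2}d}{F} = \frac{4c^{2}d}{F},
\qquad
t_{\text{comm}} \approx \frac{2 \cdot 2cd}{B} = \frac{4cd}{B},
$$

$$
t_{\text{compute}} \ge t_{\text{comm}} \iff c \ge \frac{F}{B}.
$$

Below roughly $F/B$ tokens per block the ring is communication-bound and overlap cannot hide the transfer. Adding heads multiplies both sides by the same factor, so the threshold does not move.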
JAX Ring Attention (high-level)
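```python
import functools

import jax
from jax.experimental.shard_map import shard_map
from jax.sharding import PartitionSpec as P


# ring_attention_fn(q, k, v) is the per-device body (assumed, not shown): it sees
# the local (seq_local, d) shards and rotates K/V with jax.lax.ppermute over "seq".
@functools.partial(jax.jit, static_argnames="mesh")  # Mesh is static, not a traced array
def ring_attn_pjit(q, k, v, mesh):
    return shard_map(
        ring_attention_fn,
        mesh=mesh,
        in_specs=(P("seq", None), P("seq", None), P("seq", None)),
        out_specs=P("seq", None),
    )(q, k, v)
```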
Decision criteria
| Situation | Approach |
|---|---|
| <32K context, single GPU | Flash Attention 3 only |
| 32K–256K context, single node | Flash Attention + sequence parallel |
| 256K–10M context, multi-node | Ring Attention (Striped variant) |
| Causal model | Striped Ring Attention (load balance) |
| TPU pod | JAX shard_map + Ring |
Default: Striped Ring Attention with online softmax, an NCCL ring, and a Flash Attention kernel as the inner block.
🔗 Graph
- Parent: Transformer_Architecture_and_LLM_Foundations · Distributed Training
- Applications: Gemini
- Adjacent: Flash Attention · Online Softmax
🤖 LLM usage
When: evaluating training or serving of a long-context model, or designing its infrastructure. When not: inference-only at small context, where Flash Attention alone is sufficient.
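❌ Anti-patterns
- Naive ring without overlap: communication runs serially with compute, so every step pays the transfer time on top of the compute time instead of hiding it.
- Causal mask ignored: computing fully masked (future) blocks wastes about 50% of the compute; skip them, and use the striped block order so the remaining work stays balanced across devices.
- Float32 accumulation skipped: the running LSE drifts numerically over long contexts; keep the LSE in fp32.
- Pure data parallelism for long context: each device must hold activations for the whole sequence, so it runs out of memory; use Ring Attention or context parallelism.
- Block size that does not fit in cache: the kernel becomes bandwidth-bound; tune block_size to the L2 cache.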
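The fp32 LSE advice can be seen with a toy accumulator. A NumPy sketch, assuming every block contributes lse = 0 (so the exact total is log(n_blocks)) and using float16 as a stand-in for a low-precision accumulator, since NumPy has no bfloat16; `accumulate_lse` is a hypothetical name:

```python
import math

import numpy as np


def accumulate_lse(block_lses, dtype):
    """Fold per-block LSEs into a running LSE stored in `dtype`.

    Each merge is computed in float64 and then rounded to `dtype`, which
    models an accumulator kept in that precision between ring steps.
    """
    acc = dtype(-np.inf)
    for lse in block_lses:
        acc = dtype(np.logaddexp(float(acc), float(lse)))
    return float(acc)


n_blocks = 4096
exact = math.log(n_blocks)  # every block contributes lse = 0, so the total is log(n)
for dtype in (np.float16, np.float32):
    approx = accumulate_lse([0.0] * n_blocks, dtype)
    print(f"{dtype.__name__}: {approx:.4f} (error {abs(approx - exact):.4f})")
```

The float16 accumulator stops growing at about 6.24, roughly 2.08 short of log(4096) ≈ 8.32, because each increment log(1 + e^-acc) falls below half a float16 ulp; the float32 accumulator stays within 1e-2. bfloat16 keeps 7 explicit mantissa bits against float16's 10, so it stalls even earlier.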
🧪 Verification / Duplicates
- Verified (Liu et al. 2023 arXiv:2310.01889; Megatron-LM Context Parallelism docs 2024).
- Confidence: A.
🕓 Changelog
| Date | Change |
|---|---|
| 2026-05-08 | Phase 1 |
| 2026-05-10 | Manual cleanup — Ring Attention algo + Striped variant + JAX/PyTorch patterns |