Antigravity Agent f8b21af4be Wiki cleanup: error-doc removal, dedup merge, link normalization
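Large-scale cleanup of 10_Wiki/Topics:
- Removed 227 error-capture / unfinished stub documents
- Merged 43 cross-folder duplicate clusters (63 files → redirect)
- Link-name normalization: fixed broken links, pointed redirects straight at their targets, concept mapping (~2,400 items)
- Created 6 new category MOCs
- Removed 10,058 unresolved related-keyword links from Graph sections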

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-20 23:52:15 +09:00

---
id: wiki-2026-0508-ring-attention
title: Ring Attention
category: 10_Wiki/Topics
status: verified
canonical_id: self
aliases: [Ring Self-Attention, Distributed Attention]
duplicate_of: none
source_trust_level: A
confidence_score: 0.95
verification_status: applied
tags: [attention, long-context, distributed-training, transformer, systems]
raw_sources: []
last_reinforced: 2026-05-10
github_commit: pending
tech_stack:
  language: Python
  framework: JAX/PyTorch/CUDA
---
# Ring Attention
## One-liner
> **"Split the attention sequence across N devices arranged in a ring: context length scales linearly with the number of devices."** Proposed by Liu, Zaharia, and Abbeel (2023, "Ring Attention with Blockwise Transformers"). It is a training-time enabler for 1M+ context windows (Gemini 1.5 Pro, Claude Opus 4.7 1M). Because K/V communication overlaps with compute, the overhead stays near zero as long as each block's compute takes at least as long as its transfer.
## Core
### Core idea
- Split the sequence across N devices (each device holds 1/N of the tokens of Q, K, and V).
- Each device computes attention between its local Q and the K, V blocks rotating past it.
- K, V blocks travel around the ring for N steps, and communication overlaps with attention compute.
- Result: exact full-sequence attention, with a context N times longer than fits on one device.
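You can check the core idea without GPUs. Below is a minimal single-process sketch in pure Python. It assumes the ring is modeled as an index rotation (no real send/recv) and uses small list-of-lists tensors. `dense_attention` and `simulated_ring_attention` are hypothetical names. Each simulated device keeps its Q chunk, sees one KV chunk per step, and folds that chunk in with an online softmax. The result should match dense attention to floating-point tolerance.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def dense_attention(qs, ks, vs):
    """Reference: softmax(q . k / sqrt(d)) weighted sum of v, one query at a time."""
    scale = 1.0 / math.sqrt(len(qs[0]))
    outs = []
    for q in qs:
        scores = [dot(q, k) * scale for k in ks]
        m = max(scores)
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        outs.append([sum(wi * v[j] for wi, v in zip(w, vs)) / z for j in range(len(vs[0]))])
    return outs


def simulated_ring_attention(qs, ks, vs, n_devices):
    """Single-process simulation: each 'device' keeps its Q chunk while KV chunks rotate."""
    assert len(qs) % n_devices == 0, "sequence length must divide evenly across devices"
    chunk = len(qs) // n_devices
    scale = 1.0 / math.sqrt(len(qs[0]))

    def split(xs):
        return [xs[i * chunk:(i + 1) * chunk] for i in range(n_devices)]

    q_chunks, k_chunks, v_chunks = split(qs), split(ks), split(vs)
    results = []
    for rank in range(n_devices):
        # online-softmax state per local query: running max, denominator, numerator
        state = [(-math.inf, 0.0, [0.0] * len(vs[0])) for _ in q_chunks[rank]]
        for step in range(n_devices):
            # at step s, rank r holds the KV chunk that started on rank (r + s) % N,
            # which is what "send to rank - 1, receive from rank + 1" produces
            src = (rank + step) % n_devices
            for i, q in enumerate(q_chunks[rank]):
                m, z, num = state[i]
                for k, v in zip(k_chunks[src], v_chunks[src]):
                    s = dot(q, k) * scale
                    m_new = max(m, s)
                    shrink = math.exp(m - m_new)  # rescale old terms to the new max
                    w = math.exp(s - m_new)
                    z = z * shrink + w
                    num = [x * shrink + w * vj for x, vj in zip(num, v)]
                    m = m_new
                state[i] = (m, z, num)
        results.extend([x / z for x in num] for _, z, num in state)
    return results
```

For an 8-token sequence, changing `n_devices` among 1, 2, 4, and 8 changes only how the work is partitioned, not the output.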
### vs alternatives
- **Flash Attention**: single-device, IO-aware, memory-efficient. Ring Attention composes on top of it, with Flash Attention as the inner block kernel.
- **Sequence Parallel (Megatron)**: a similar sequence split, but applied only to the layernorm/dropout regions. Attention itself is not split along the sequence.
- **Context Parallel (Megatron 2024)**: a production Ring Attention variant.
- **Striped Attention** (2023): improves load balance under causal masks.
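To see why striping helps with causal masks, count the attention pairs each device is responsible for. This is a minimal sketch with hypothetical helper names. It tallies total work per device at token granularity, which simplifies the per-step imbalance that Striped Attention addresses in a real ring.

```python
def contiguous_assignment(seq_len, world_size):
    """Device r owns the r-th contiguous chunk of token positions."""
    chunk = seq_len // world_size
    return [list(range(r * chunk, (r + 1) * chunk)) for r in range(world_size)]


def striped_assignment(seq_len, world_size):
    """Device r owns positions r, r + W, r + 2W, ... (W = world_size)."""
    return [list(range(r, seq_len, world_size)) for r in range(world_size)]


def causal_work_per_device(assignment):
    """(query, key) pairs each device computes under a causal mask.

    A query at position p attends to keys 0..p, i.e. p + 1 pairs.
    """
    return [sum(p + 1 for p in positions) for positions in assignment]


print(causal_work_per_device(contiguous_assignment(16, 4)))  # [10, 26, 42, 58]
print(causal_work_per_device(striped_assignment(16, 4)))     # [28, 32, 36, 40]
```

Both layouts do the same total work (136 pairs for 16 tokens). With contiguous chunks, the last device does almost 6x the work of the first, while striping keeps every device within a narrow band.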
### Applications
1. 1M+ context LLM training (Gemini 1.5/2.0, Claude Opus 4.x).
2. Long video understanding.
3. Whole-codebase code models.
4. Long DNA sequence models (Evo).
## 💻 Patterns
### Conceptual Ring Loop (single block)
```python
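import torch
import torch.distributed as dist


def blockwise_attention(q, k, v):
    """Dense attention of one Q block against one KV block (stand-in for a Flash Attention kernel).

    Returns the normalized output and the per-query log-sum-exp (LSE) of the scores.
    """
    scores = (q @ k.transpose(-2, -1)) / q.shape[-1] ** 0.5
    lse = torch.logsumexp(scores, dim=-1)
    out = torch.exp(scores - lse.unsqueeze(-1)) @ v
    return out, lse


def ring_attention_step(q_local, kv_local, world_size):
    """Simplified single-pass illustration: non-causal, no compute/comm overlap."""
    out = torch.zeros_like(q_local)
    lse = torch.full(q_local.shape[:-1], -float("inf"), device=q_local.device)
    k, v = kv_local
    rank = dist.get_rank()
    send_rank = (rank - 1) % world_size  # pass K, V to the left neighbor
    recv_rank = (rank + 1) % world_size  # receive K, V from the right neighbor
    for step in range(world_size):
        # attention of the local Q against the KV block currently held
        partial_out, partial_lse = blockwise_attention(q_local, k, v)
        out, lse = online_softmax_merge(out, lse, partial_out, partial_lse)
        # rotate K, V (helpers defined below); the last step needs no further rotation
        if step < world_size - 1:
            k, v = ring_send_recv(k, v, send_rank, recv_rank)
    return out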
```
### Online Softmax Merge
```python
import torch


def online_softmax_merge(out_a, lse_a, out_b, lse_b):
    """Numerically stable merge of two partial attention results.

    Each partial output is already normalized within its block. The merge weights them by
    exp(lse_a) and exp(lse_b), computed relative to their max to avoid overflow.
    """
    m = torch.maximum(lse_a, lse_b)
    c_a = torch.exp(lse_a - m).unsqueeze(-1)
    c_b = torch.exp(lse_b - m).unsqueeze(-1)
    out = (c_a * out_a + c_b * out_b) / (c_a + c_b)
    new_lse = m + torch.log(torch.exp(lse_a - m) + torch.exp(lse_b - m))
    return out, new_lse
```
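The merge rule is easiest to check on scalars. Below is a pure-Python restatement of what `online_softmax_merge` computes, for one query with scalar values (`partial_attention` and `merge` are hypothetical names). It splits the scores into two blocks, computes each block's normalized output and LSE, merges them, and compares the result against a single softmax over all scores.

```python
import math


def partial_attention(scores, values):
    """Normalized softmax-weighted value and log-sum-exp for one block of scores."""
    m = max(scores)
    lse = m + math.log(sum(math.exp(s - m) for s in scores))
    out = sum(math.exp(s - lse) * v for s, v in zip(scores, values))
    return out, lse


def merge(out_a, lse_a, out_b, lse_b):
    """Scalar version of online_softmax_merge: weight each block by exp(lse - max)."""
    m = max(lse_a, lse_b)
    c_a, c_b = math.exp(lse_a - m), math.exp(lse_b - m)
    return (c_a * out_a + c_b * out_b) / (c_a + c_b), m + math.log(c_a + c_b)


scores = [0.3, -1.2, 2.5, 0.0, 1.1, -0.7]
values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
full = partial_attention(scores, values)
merged = merge(*partial_attention(scores[:2], values[:2]),
               *partial_attention(scores[2:], values[2:]))
print(full, merged)  # equal up to rounding
```

Subtracting the max before `exp` is what keeps scores like 1000.0 from overflowing.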
### Ring Send/Recv (NCCL)
```python
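import torch
import torch.distributed as dist


def ring_send_recv(k, v, send_rank, recv_rank):
    """Send local K, V to send_rank and receive the neighbor's K, V from recv_rank."""
    k_buf = torch.empty_like(k)
    v_buf = torch.empty_like(v)
    ops = [
        dist.P2POp(dist.isend, k, send_rank),
        dist.P2POp(dist.isend, v, send_rank),
        dist.P2POp(dist.irecv, k_buf, recv_rank),
        dist.P2POp(dist.irecv, v_buf, recv_rank),
    ]
    # batching groups the sends and receives into one call, avoiding the ordering
    # deadlock separate isend/irecv calls can hit in a ring on NCCL
    for req in dist.batch_isend_irecv(ops):
        req.wait()
    return k_buf, v_buf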
```
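Every send must meet a matching receive on the peer, so the neighbor rule is worth checking on its own. This is a hedged pure-Python sketch with hypothetical helpers. It models only the rank arithmetic, not NCCL. It checks that the send-left/receive-right rule is consistent and that every rank sees every KV block exactly once over `world_size` steps.

```python
def ring_neighbors(rank, world_size):
    """(send_rank, recv_rank) for `rank`, using the same rule as ring_attention_step."""
    return (rank - 1) % world_size, (rank + 1) % world_size


def sends_are_matched(world_size, neighbors=ring_neighbors):
    """True if every rank's send target expects to receive from that rank."""
    for rank in range(world_size):
        send_to, _ = neighbors(rank, world_size)
        _, target_recvs_from = neighbors(send_to, world_size)
        if target_recvs_from != rank:
            return False
    return True


def blocks_seen(rank, world_size):
    """Original owners of the KV blocks `rank` holds over world_size steps."""
    return [(rank + step) % world_size for step in range(world_size)]


print(blocks_seen(1, 4))  # [1, 2, 3, 0]
```

Consider a broken rule such as `lambda r, w: ((r - 1) % w, (r - 1) % w)`, which receives from the same neighbor it sends to. It fails the check for `world_size=3`: rank 0 sends to rank 2, but rank 2 expects data from rank 1.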
### Striped (Causal-aware) Block Order
```python
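def striped_block_order(seq_len, world_size, block_size):
    """Block permutation that stripes blocks across devices for causal load balance.

    Device r receives blocks r, r + W, r + 2W, ... (stride W = world_size); the list is
    every device's blocks concatenated in rank order.
    """
    n_blocks = seq_len // block_size
    assert n_blocks % world_size == 0, "n_blocks must be divisible by world_size"
    return [i * world_size + r
            for r in range(world_size)
            for i in range(n_blocks // world_size)]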
```
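Attention on a striped layout produces outputs in striped order, so they have to be permuted back. Below is a minimal sketch with hypothetical helpers, using plain Python lists as stand-ins for tensors. The example `order` is what `striped_block_order(16, 4, 2)` returns: 8 blocks on 4 devices.

```python
def apply_order(blocks, order):
    """Reorder blocks so that position j holds blocks[order[j]]."""
    return [blocks[i] for i in order]


def inverse_permutation(order):
    """Index list that maps a striped layout back to the original sequence order."""
    inv = [0] * len(order)
    for new_pos, block in enumerate(order):
        inv[block] = new_pos
    return inv


order = [0, 4, 1, 5, 2, 6, 3, 7]
blocks = [f"b{i}" for i in range(8)]
striped = apply_order(blocks, order)  # ['b0', 'b4', 'b1', 'b5', 'b2', 'b6', 'b3', 'b7']
restored = apply_order(striped, inverse_permutation(order))  # back to b0..b7
```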
### Causal Mask Skip Optimization
```python
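def should_compute(q_block_idx, kv_block_idx, causal=True):
    """Under a causal mask, skip KV blocks entirely in the future (kv_block_idx > q_block_idx).

    The diagonal block (kv == q) is still computed and needs an intra-block causal mask.
    """
    return (not causal) or kv_block_idx <= q_block_idx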
```
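Counting the block pairs that survive the skip shows where the roughly 50% saving comes from. With n blocks per side, n(n+1)/2 of the n² pairs pass the causal test, a fraction of (n+1)/(2n). The pure-Python sketch below counts them directly; `computed_block_fraction` is a hypothetical helper using the same test as `should_compute`.

```python
def computed_block_fraction(n_blocks, causal=True):
    """Fraction of (q_block, kv_block) pairs computed when future KV blocks are skipped."""
    computed = sum(
        1
        for q in range(n_blocks)
        for kv in range(n_blocks)
        if (not causal) or kv <= q  # same test as should_compute
    )
    return computed / (n_blocks * n_blocks)


print(computed_block_fraction(4))     # 0.625
print(computed_block_fraction(1000))  # 0.5005
```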
### Compute/Comm Overlap (CUDA streams)
```python
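import torch


def overlapped_step(q, k, v, send_rank, recv_rank, compute_stream, comm_stream):
    """Compute on the current KV block while the next block is in flight."""
    # side streams must not start before k, v are ready on the current stream
    compute_stream.wait_stream(torch.cuda.current_stream())
    comm_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(compute_stream):
        partial = blockwise_attention(q, k, v)
    with torch.cuda.stream(comm_stream):
        next_k, next_v = ring_send_recv(k, v, send_rank, recv_rank)
    # join both streams into the current one instead of a device-wide synchronize()
    torch.cuda.current_stream().wait_stream(compute_stream)
    torch.cuda.current_stream().wait_stream(comm_stream)
    return partial, next_k, next_v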
```
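Overlap only hides communication if computing on one block takes at least as long as shipping the next one. The back-of-envelope sketch below is in the spirit of the Ring Attention paper's analysis. The FLOP and byte counts are stated assumptions: forward pass only, matmul FLOPs only, bf16 K/V. The hardware numbers in the usage lines are illustrative, not measured or vendor figures.

```python
def min_block_size(flops_per_sec, bytes_per_sec):
    """Smallest block size c (tokens per device) at which compute hides the K/V transfer.

    Assumed cost model (forward pass, matmul FLOPs only, bf16 K/V):
      compute per step = 4 * d * c**2 FLOPs  (Q K^T and P V, 2 FLOPs per multiply-add)
      traffic per step = 4 * c * d bytes     (K and V blocks, 2 bytes per element)
    Overlap needs 4*d*c**2 / F >= 4*c*d / B, i.e. c >= F / B; the head dim d cancels.
    """
    return flops_per_sec / bytes_per_sec


def transfer_is_hidden(c, d, flops_per_sec, bytes_per_sec):
    """True if the per-step compute time is at least the per-step transfer time."""
    compute_s = 4 * d * c**2 / flops_per_sec
    comm_s = 4 * c * d / bytes_per_sec
    return compute_s >= comm_s


# illustrative numbers: 300 TFLOP/s compute, 100 GB/s link
print(min_block_size(300e12, 100e9))                 # 3000.0
print(transfer_is_hidden(4096, 128, 300e12, 100e9))  # True
print(transfer_is_hidden(1024, 128, 300e12, 100e9))  # False
```

Since the condition does not depend on d, a slower link directly raises the minimum number of tokens each device must hold.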
### JAX Ring Attention (high-level)
```python
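import functools

import jax
from jax.experimental.shard_map import shard_map
from jax.sharding import PartitionSpec as P


# `mesh` is not an array, so it must be a static argument to jit
@functools.partial(jax.jit, static_argnames="mesh")
def ring_attn_pjit(q, k, v, mesh):
    # ring_attention_fn: per-shard kernel (not shown) that sees one "seq" shard of
    # q, k, v and rotates K/V around the "seq" mesh axis, e.g. with jax.lax.ppermute
    return shard_map(
        ring_attention_fn,
        mesh=mesh,
        in_specs=(P("seq", None), P("seq", None), P("seq", None)),
        out_specs=P("seq", None),
    )(q, k, v)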
```
## Decision criteria
| Situation | Approach |
|---|---|
| <32K context, single GPU | Flash Attention 3 only |
| 32K-256K context, single node | Flash Attention + sequence parallel |
| 256K-10M context, multi-node | Ring Attention (Striped variant) |
| Causal model | Striped Ring Attention (load balance) |
| TPU pod | JAX shard_map + Ring |

**Default**: Striped Ring Attention with online softmax + NCCL ring + a Flash Attention kernel as the inner block.
## 🔗 Graph
- Parent: [[Transformer_Architecture_and_LLM_Foundations|Attention Mechanism]] · [[Distributed Training]]
- Applications: [[Gemini]]
- Adjacent: [[Flash Attention]] · [[Online Softmax]]
## 🤖 LLM usage
**When**: evaluating how to train or serve a long-context model, or designing its infrastructure.
**When not**: inference-only workloads at small context, where Flash Attention alone is sufficient.
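## ❌ Anti-patterns
- **Naive ring without overlap**: communication is serialized with compute, so every step pays the full K/V transfer time instead of hiding it.
- **Causal mask ignored**: computing fully masked (future) blocks wastes roughly 50% of compute. Skip them, and use striped order so the remaining work stays balanced across devices.
- **Float32 accumulation skipped**: numerical drift builds up over long contexts. Keep the LSE in fp32.
- **Pure data parallel for long context**: each device still holds the full sequence, so memory becomes the bottleneck. Use Ring Attention or context parallelism.
- **Block size that does not fit in cache**: the kernel becomes bandwidth-bound. Tune block_size to fit L2.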
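Here is a toy illustration of why accumulators such as the LSE stay in fp32. It assumes fp16, emulated by rounding through Python's `struct` half-precision `'e'` format, as a stand-in for a low-precision accumulator. Real kernels accumulate differently, but the failure mode is of the same kind: once the running total is large, each small increment falls below half a unit in the last place and rounds away.

```python
import struct


def to_fp16(x):
    """Round a Python float to IEEE half precision using struct's 'e' format."""
    return struct.unpack("e", struct.pack("e", x))[0]


def running_sum(values, fp16_accumulator):
    """Sum values, optionally rounding the running total to fp16 after every add."""
    total = 0.0
    for v in values:
        total = total + v
        if fp16_accumulator:
            total = to_fp16(total)
    return total


increments = [1e-3] * 10_000  # exact total: 10.0
print(running_sum(increments, fp16_accumulator=False))  # close to 10.0
print(running_sum(increments, fp16_accumulator=True))   # stalls far below 10.0
```

The longer the sequence, the more partial results are folded into the same accumulator, so the stall point arrives sooner relative to the work done.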
## 🧪 Verification / Duplicates
- Verified (Liu et al. 2023 arXiv:2310.01889; Megatron-LM Context Parallelism docs 2024).
- Trust level: A.
## 🕓 Changelog
| Date | Change |
|---|---|
| 2026-05-08 | Phase 1 |
| 2026-05-10 | Manual cleanup — Ring Attention algo + Striped variant + JAX/PyTorch patterns |