id: wiki-2026-0508-ring-attention
title: Ring Attention
category: 10_Wiki/Topics
status: verified
canonical_id: self
aliases: Ring Self-Attention, Distributed Attention
duplicate_of: none
source_trust_level: A
confidence_score: 0.95
verification_status: applied
tags: attention, long-context, distributed-training, transformer, systems
raw_sources:
last_reinforced: 2026-05-10
github_commit: pending
tech_stack:
  language: Python
  framework: JAX/PyTorch/CUDA

Ring Attention

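One-line summary

"Split the attention sequence across N devices arranged in a ring, so that context length scales linearly with the number of devices." Proposed by Liu, Zaharia, and Abbeel in 2023 ("Ring Attention with Blockwise Transformers"). It is a training-time enabler for 1M+ token context windows of the kind offered by Gemini 1.5 Pro and Claude Opus 4.7 (1M). Because K/V communication overlaps with attention compute, the communication overhead is close to zero when blocks are large enough.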

Key points

Core idea

  • Split the sequence across N devices (each device holds 1/N of the tokens of Q, K, V).
  • Each device computes attention with its local Q against rotating K, V blocks.
  • K, V blocks travel around the ring for N steps; communication overlaps with attention compute.
  • Result: exact full-sequence attention, with N devices fitting a context about N times longer than one device can hold.
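The four points above can be checked on a single process. The following is a minimal sketch in plain Python (no GPUs, no torch) that simulates the ring by rotating a list of K/V shards and compares the result with ordinary attention over the whole sequence. The helper names `block_attention`, `merge`, and `simulated_ring_attention` are illustrative assumptions and do not come from the paper or any library; the sketch is non-causal and forward-only.

```python
import math
import random


def block_attention(q_rows, k_rows, v_rows):
    """Attention of q_rows against one K/V block.

    Returns, per query row, the normalized output and the log-sum-exp of its scores.
    """
    d = len(q_rows[0])
    outs, lses = [], []
    for q in q_rows:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in k_rows]
        m = max(scores)
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        outs.append([sum(w * v[j] for w, v in zip(weights, v_rows)) / total
                     for j in range(len(v_rows[0]))])
        lses.append(m + math.log(total))
    return outs, lses


def merge(out_a, lse_a, out_b, lse_b):
    """Combine two partial results for one query row (online softmax)."""
    m = max(lse_a, lse_b)
    w_a, w_b = math.exp(lse_a - m), math.exp(lse_b - m)  # exp(-inf) == 0.0
    out = [(w_a * a + w_b * b) / (w_a + w_b) for a, b in zip(out_a, out_b)]
    return out, m + math.log(w_a + w_b)


def simulated_ring_attention(q, k, v, n_devices):
    """Simulate the ring: device i keeps Q shard i while K/V shards rotate."""
    shard = len(q) // n_devices  # assumes len(q) is a multiple of n_devices
    split = lambda x: [x[i * shard:(i + 1) * shard] for i in range(n_devices)]
    q_sh, k_sh, v_sh = split(q), split(k), split(v)
    outs = [[[0.0] * len(v[0]) for _ in range(shard)] for _ in range(n_devices)]
    lses = [[-math.inf] * shard for _ in range(n_devices)]
    for _ in range(n_devices):
        for dev in range(n_devices):
            p_out, p_lse = block_attention(q_sh[dev], k_sh[dev], v_sh[dev])
            for r in range(shard):
                outs[dev][r], lses[dev][r] = merge(
                    outs[dev][r], lses[dev][r], p_out[r], p_lse[r])
        # Rotate: every device hands its K/V shard to the next device in the ring.
        k_sh = k_sh[-1:] + k_sh[:-1]
        v_sh = v_sh[-1:] + v_sh[:-1]
    return [row for dev_out in outs for row in dev_out]


random.seed(0)
seq_len, dim, n_devices = 8, 4, 4
rand_matrix = lambda: [[random.gauss(0, 1) for _ in range(dim)] for _ in range(seq_len)]
q, k, v = rand_matrix(), rand_matrix(), rand_matrix()

ring_out = simulated_ring_attention(q, k, v, n_devices)
full_out, _ = block_attention(q, k, v)
max_err = max(abs(a - b) for ra, rb in zip(ring_out, full_out) for a, b in zip(ra, rb))
assert max_err < 1e-9  # the ring result matches single-device attention
```

Each simulated device only ever touches one K/V shard at a time, which is the reason per-device memory stays at 1/N of the sequence.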

Comparison with alternatives

  • Flash Attention: single device, IO-aware, memory-efficient. Ring Attention composes on top of it.
  • Sequence Parallel (Megatron): also splits along the sequence dimension, but only for the LayerNorm and dropout regions, not for attention itself.
  • Context Parallel (Megatron 2024): industrial Ring Attention variant.
  • Striped Attention (2023): improved load balance for causal masks.

Applications

  1. 1M+ context LLM training (models in the class of Gemini 1.5/2.0, Claude Opus 4.x).
  2. Long video understanding.
  3. Whole-codebase code models.
  4. Long DNA sequence models (Evo).

💻 Patterns

Conceptual Ring Loop (single block)

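import torch
import torch.distributed as dist

def blockwise_attention(q, k, v):
    """Reference inner block (stand-in for a Flash Attention kernel); non-causal."""
    scores = (q @ k.transpose(-2, -1)).float() / q.shape[-1] ** 0.5
    lse = torch.logsumexp(scores, dim=-1)  # fp32 log-sum-exp per query
    out = torch.exp(scores - lse.unsqueeze(-1)) @ v.float()
    return out, lse

def ring_attention_step(q_local, kv_local, world_size):
    """Simplified single-pass illustration (forward only, no causal mask, no overlap)."""
    out = torch.zeros_like(q_local, dtype=torch.float32)  # accumulate in fp32
    lse = torch.full(q_local.shape[:-1], -float("inf"), device=q_local.device)

    k, v = kv_local
    rank = dist.get_rank()
    # Fixed ring neighbors: send to rank - 1, receive from rank + 1.
    send_rank = (rank - 1) % world_size
    recv_rank = (rank + 1) % world_size

    for step in range(world_size):
        # Attention of the local queries against the K/V block currently held.
        partial_out, partial_lse = blockwise_attention(q_local, k, v)
        out, lse = online_softmax_merge(out, lse, partial_out, partial_lse)

        # Rotate K, V to the neighbor; the last step needs no transfer.
        if step < world_size - 1:
            k, v = ring_send_recv(k, v, send_rank, recv_rank)
    return out.to(q_local.dtype)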

Online Softmax Merge

def online_softmax_merge(out_a, lse_a, out_b, lse_b):
    """Numerically stable merge of two partial attention results.

    out_* are already-normalized partial outputs; lse_* are their log-sum-exp values.
    """
    m = torch.maximum(lse_a, lse_b)
    w_a = torch.exp(lse_a - m)  # relative weight of partial A
    w_b = torch.exp(lse_b - m)  # relative weight of partial B
    c_a, c_b = w_a.unsqueeze(-1), w_b.unsqueeze(-1)
    out = (c_a * out_a + c_b * out_b) / (c_a + c_b)
    new_lse = m + torch.log(w_a + w_b)
    return out, new_lse
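
The merge follows from how log-sum-exp combines over disjoint key sets. A short derivation for a single query row, using symbols that are not in the original note: $s_j$ is the scaled score of key $j$, $v_j$ its value vector, and $A$, $B$ are two disjoint sets of key positions.

```latex
\begin{aligned}
\ell_A &= \log \sum_{j \in A} e^{s_j}, \qquad
o_A = \sum_{j \in A} e^{s_j - \ell_A}\, v_j \\[4pt]
\ell_{A \cup B} &= \log\left(e^{\ell_A} + e^{\ell_B}\right)
  = m + \log\left(e^{\ell_A - m} + e^{\ell_B - m}\right),
  \qquad m = \max(\ell_A, \ell_B) \\[4pt]
o_{A \cup B} &= \frac{e^{\ell_A}\, o_A + e^{\ell_B}\, o_B}{e^{\ell_A} + e^{\ell_B}}
  = \frac{e^{\ell_A - m}\, o_A + e^{\ell_B - m}\, o_B}{e^{\ell_A - m} + e^{\ell_B - m}}
\end{aligned}
```

Subtracting $m$ keeps every exponent at or below zero, so nothing overflows, and the last line is the expression computed by `online_softmax_merge`.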

Ring Send/Recv (NCCL)

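def ring_send_recv(k, v, send_rank, recv_rank):
    """Send the local K, V block to send_rank and receive the next block from recv_rank."""
    k, v = k.contiguous(), v.contiguous()
    k_buf = torch.empty_like(k)
    v_buf = torch.empty_like(v)
    # Post all four point-to-point ops as one batch; on NCCL, sends and receives
    # issued one by one around a ring may hang.
    ops = [
        dist.P2POp(dist.isend, k, send_rank), dist.P2POp(dist.isend, v, send_rank),
        dist.P2POp(dist.irecv, k_buf, recv_rank), dist.P2POp(dist.irecv, v_buf, recv_rank),
    ]
    for req in dist.batch_isend_irecv(ops):
        req.wait()
    return k_buf, v_buf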

Striped (Causal-aware) Block Order

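def striped_block_order(seq_len, world_size, block_size):
    """Block permutation for causal load balance: interleave with stride world_size.

    Rank r receives blocks r, r + world_size, r + 2 * world_size, ...
    """
    n_blocks = seq_len // block_size
    assert n_blocks % world_size == 0, "blocks must divide evenly across ranks"
    per_rank = n_blocks // world_size
    return [i * world_size + r
            for r in range(world_size)
            for i in range(per_rank)]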

Causal Mask Skip Optimization

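def should_compute(q_block_idx, kv_block_idx, causal=True):
    """Causal: skip K/V blocks that lie entirely in the future (kv > q).

    Block indices are assumed to follow sequence order; the diagonal block
    (kv == q) still needs an element-wise causal mask inside the kernel.
    """
    return (not causal) or kv_block_idx <= q_block_idx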
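
To see why skipping future blocks is not enough on its own, and why a striped layout helps, the sketch below counts the unmasked (query, key) pairs each device processes at every ring step. It is plain Python with token-level striping; the function names and the toy sizes (16 tokens, 4 devices) are assumptions chosen for illustration, and real kernels work on blocks rather than individual pairs.

```python
def causal_pairs(q_positions, k_positions):
    """Number of (query, key) pairs a causal mask keeps: key position <= query position."""
    return sum(1 for qp in q_positions for kp in k_positions if kp <= qp)


def layout(seq_len, world_size, striped):
    """Token positions held by each device under a contiguous or striped split."""
    if striped:
        return [list(range(d, seq_len, world_size)) for d in range(world_size)]
    shard = seq_len // world_size
    return [list(range(d * shard, (d + 1) * shard)) for d in range(world_size)]


def per_step_work(seq_len, world_size, striped):
    """For each ring step, the unmasked pair count on every device."""
    tokens = layout(seq_len, world_size, striped)
    steps = []
    for step in range(world_size):
        # At this step device d holds the K/V shard that started on device (d - step) % W.
        steps.append([causal_pairs(tokens[d], tokens[(d - step) % world_size])
                      for d in range(world_size)])
    return steps


def critical_path(steps):
    """Every ring step lasts as long as its busiest device."""
    return sum(max(step) for step in steps)


contiguous = per_step_work(16, 4, striped=False)
striped = per_step_work(16, 4, striped=True)
print(critical_path(contiguous))  # 58
print(critical_path(striped))     # 40
```

Both layouts do the same total work (136 pairs), but with a contiguous split some device always has a fully unmasked block while others sit idle, so the ring advances at the pace of the busiest device.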

Compute/Comm Overlap (CUDA streams)

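def overlapped_step(q, k, v, send_rank, recv_rank, compute_stream, comm_stream):
    """Compute on the current K, V block while the next block is in flight."""
    current = torch.cuda.current_stream()
    compute_stream.wait_stream(current)  # inputs were produced on the current stream
    comm_stream.wait_stream(current)
    with torch.cuda.stream(comm_stream):
        next_k, next_v = ring_send_recv(k, v, send_rank, recv_rank)
    with torch.cuda.stream(compute_stream):
        partial = blockwise_attention(q, k, v)
    current.wait_stream(compute_stream)  # rejoin both streams before results are used
    current.wait_stream(comm_stream)
    return partial, next_k, next_v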

JAX Ring Attention (high-level)

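from functools import partial

import jax
from jax.experimental.shard_map import shard_map
from jax.sharding import PartitionSpec as P

# mesh is not an array, so it must be a static argument under jit.
@partial(jax.jit, static_argnames="mesh")
def ring_attn_pjit(q, k, v, mesh):
    # ring_attention_fn is the per-shard function (defined elsewhere). It sees the
    # local [seq / n_devices, dim] slices and rotates K, V along the "seq" mesh axis.
    return shard_map(
        ring_attention_fn,
        mesh=mesh,
        in_specs=(P("seq", None), P("seq", None), P("seq", None)),
        out_specs=P("seq", None),
    )(q, k, v)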

Decision criteria

| Situation | Approach |
|---|---|
| <32K context, single GPU | Flash Attention 3 only |
| 32K–256K context, single node | Flash Attention + sequence parallel |
| 256K–10M context, multi-node | Ring Attention (Striped variant) |
| Causal model | Striped Ring Attention (load balance) |
| TPU pod | JAX shard_map + Ring |

Default: Striped Ring Attention with online softmax + NCCL ring + Flash Attention kernel as inner block.

🔗 Graph

🤖 LLM usage

When to use: when evaluating or designing infrastructure to train or serve long-context models.
When not to use: inference-only workloads at small context, where Flash Attention alone is sufficient.

Anti-patterns

  • Naive ring without overlap: communication runs sequentially with compute, so there is no speedup.
  • Causal mask ignored: about 50% of the compute goes to masked (future) positions and is wasted; a striped order fixes this.
  • Float32 accumulation skipped: numerical drift at long context; keep the LSE in fp32.
  • Pure data parallel for long context: memory-bound; use Ring Attention or context parallelism instead.
  • Block size that does not fit in cache: bandwidth-bound; tune block_size to L2.
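
Whether overlap can hide communication at all depends on the block size: compute per ring step grows with the square of the block length, while the K/V transfer grows linearly. The following back-of-the-envelope sketch makes that trade-off concrete. The cost model (4·c²·d FLOPs per step, one K block plus one V block transferred), the function names, and the hardware numbers are all assumptions for illustration, not measurements.

```python
import math


def ring_step_times(block_tokens, hidden_dim, flops_per_s, link_bytes_per_s,
                    bytes_per_elem=2):
    """Rough per-device estimate for one ring step, in seconds.

    compute:  Q @ K^T and P @ V are each about 2 * c * c * d FLOPs.
    transfer: one K block and one V block, c * d elements each.
    """
    c, d = block_tokens, hidden_dim
    compute_s = 4 * c * c * d / flops_per_s
    transfer_s = 2 * c * d * bytes_per_elem / link_bytes_per_s
    return compute_s, transfer_s


def min_block_for_overlap(flops_per_s, link_bytes_per_s, bytes_per_elem=2):
    """Smallest block length c with compute_s >= transfer_s.

    Solving 4*c*c*d / F >= 2*c*d*b / B gives c >= b * F / (2 * B).
    """
    return math.ceil(bytes_per_elem * flops_per_s / (2 * link_bytes_per_s))


# Hypothetical accelerator: 300 TFLOP/s sustained, 100 GB/s per ring link, bf16.
FLOPS, LINK = 300e12, 100e9
print(min_block_for_overlap(FLOPS, LINK))  # 3000 tokens per device

for c in (1024, 4096):
    compute_s, transfer_s = ring_step_times(c, 4096, FLOPS, LINK)
    print(c, "overlap hides transfer" if compute_s >= transfer_s else "transfer-bound")
```

With two-byte elements the bound reduces to c ≥ F/B, so slower interconnects need proportionally longer per-device blocks before the ring stops being communication-bound.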

🧪 Verification / duplicates

  • Verified (Liu et al. 2023 arXiv:2310.01889; Megatron-LM Context Parallelism docs 2024).
  • Trust level: A.

🕓 Changelog

| Date | Change |
|---|---|
| 2026-05-08 | Phase 1 |
| 2026-05-10 | Manual cleanup: Ring Attention algo + Striped variant + JAX/PyTorch patterns |