Antigravity Agent f8b21af4be Wiki cleanup: error-doc removal, dedup merge, link normalization
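Large-scale cleanup of 10_Wiki/Topics:
- Removed 227 error-capture / incomplete stub documents
- Merged 43 cross-folder duplicate clusters (63 files → redirect)
- Link-name normalization: fixed broken links, pointed redirects directly at their targets, concept mapping (~2,400 changes)
- Created 6 new category MOCs
- Removed 10,058 unresolved related-keyword links from Graph sections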

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-20 23:52:15 +09:00


---
id: wiki-2026-0508-just-in-time-jit
title: Just-In-Time (JIT)
category: 10_Wiki/Topics
status: verified
canonical_id: self
aliases: [JIT Compilation, Dynamic Compilation, Tracing JIT]
duplicate_of: none
source_trust_level: A
confidence_score: 0.95
verification_status: applied
tags: [compiler, optimization, runtime, performance, llvm]
raw_sources: []
last_reinforced: 2026-05-10
github_commit: pending
tech_stack:
  language: python
  framework: jax
---
# Just-In-Time (JIT)
## One-liner
> **"Compile on the first call, reuse on the hot path."** JIT compilation translates source, bytecode, or IR into native code at runtime; many JITs also use runtime profiling to focus optimization on hot regions. In 2026-era ML, JAX `jit`, PyTorch 2.x `torch.compile`, Mojo, and Julia are mainstream.
## Core ideas
### JIT mechanics
- **Trace**: capture the computational graph for the observed input shapes / dtypes.
- **Specialize**: generate a kernel specialized for those fixed shapes.
- **Cache**: store `(function, signature) → compiled artifact`.
- **Recompile**: a shape change is a cache miss that forces a recompile (avoid this in hot loops).
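
The four steps can be seen in miniature below. This is a hypothetical, stdlib-only sketch (the `ToyJIT` class and its unrolled dot product are illustrative inventions, not any real library's API): the list length plays the role of the traced shape, and Python's built-in `compile()` stands in for a native code generator.

```python
class ToyJIT:
    """Hypothetical toy JIT that specializes a dot product on input length."""

    def __init__(self):
        self.cache = {}          # (function name, signature) -> compiled function
        self.compile_count = 0

    def _specialize(self, n):
        # Specialize: unroll the loop for exactly n elements,
        # e.g. n=2 -> "a[0] * b[0] + a[1] * b[1]"
        body = " + ".join(f"a[{i}] * b[{i}]" for i in range(n)) or "0"
        src = f"def dot_{n}(a, b):\n    return {body}\n"
        namespace = {}
        exec(compile(src, f"<toyjit dot_{n}>", "exec"), namespace)
        self.compile_count += 1
        return namespace[f"dot_{n}"]

    def dot(self, a, b):
        if len(a) != len(b):
            raise ValueError("a and b must have the same length")
        sig = ("dot", len(a))            # trace: the length plays the role of a shape
        fn = self.cache.get(sig)
        if fn is None:                   # cache miss: specialize and compile
            fn = self._specialize(len(a))
            self.cache[sig] = fn
        return fn(a, b)                  # cache hit: reuse the compiled code


jit = ToyJIT()
print(jit.dot([1, 2, 3], [4, 5, 6]))  # compiles dot_3 -> 32
print(jit.dot([1, 1, 1], [2, 2, 2]))  # reuses dot_3 -> 6
print(jit.dot([1, 2], [3, 4]))        # new length: recompiles (dot_2) -> 11
print(jit.compile_count)              # -> 2
```

Every new length costs one more compile, which is exactly why a hot loop over varying shapes thrashes.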
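### vs. AOT
- **AOT (ahead-of-time)**: rustc, gcc. Fast startup, but optimization cannot use runtime information (e.g., which targets a dynamic dispatch actually hits).
- **JIT**: uses runtime information for better inlining and specialization, at the cost of compile time during startup / warm-up.
- **Hybrid (tiered)**: PyPy, V8, .NET. Start in an interpreter or a cheap baseline tier, then optimize code once it has run N times.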
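
A minimal sketch of the tiered idea, assuming a tiny expression language: `TieredExpr` and `interpret` are hypothetical names, the slow tier walks the Python AST on every call, and the fast tier is a code object produced once by the built-in `compile()`. Real engines promote larger units (functions, loops) and can also deoptimize when assumptions break, which this sketch omits.

```python
import ast


def interpret(node, env):
    """Slow tier: evaluate an expression by walking its AST on every call."""
    if isinstance(node, ast.Expression):
        return interpret(node.body, env)
    if isinstance(node, ast.BinOp):
        left, right = interpret(node.left, env), interpret(node.right, env)
        if isinstance(node.op, ast.Add):
            return left + right
        if isinstance(node.op, ast.Sub):
            return left - right
        if isinstance(node.op, ast.Mult):
            return left * right
        raise NotImplementedError(type(node.op).__name__)
    if isinstance(node, ast.Name):
        return env[node.id]
    if isinstance(node, ast.Constant):
        return node.value
    raise NotImplementedError(type(node).__name__)


class TieredExpr:
    """Interpret first; after `threshold` calls, compile once and use the fast tier."""

    def __init__(self, source, threshold=3):
        self.tree = ast.parse(source, mode="eval")
        self.threshold = threshold
        self.calls = 0
        self.compiled = None
        self.tier = "interpreter"

    def __call__(self, **env):
        if self.compiled is not None:
            return eval(self.compiled, {}, env)   # fast tier: precompiled code object
        self.calls += 1
        if self.calls >= self.threshold:          # hot: promote to the compiled tier
            self.compiled = compile(self.tree, "<tiered>", "eval")
            self.tier = "compiled"
        return interpret(self.tree, env)


f = TieredExpr("x * x + 1", threshold=3)
results = [f(x=n) for n in range(2, 6)]  # calls 1-3 interpreted (promotion on call 3), call 4 compiled
```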
### ML JIT specifics
- **Static shapes**: JAX `jit` specializes on the traced shapes; a new input shape triggers a retrace.
- **XLA / Triton backends**: generate fused kernels, because ML workloads are often bound by memory bandwidth rather than compute.
- **Compilation cache**: a persistent on-disk cache mitigates cold-start compile time.
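
The bandwidth argument for fusion can be made concrete with a back-of-the-envelope model. This is a simplifying assumption, not how XLA or Triton actually cost kernels: each unfused elementwise op is its own kernel that reads its input from memory and writes its result back, while the fused version reads once and writes once; caches and extra operands are ignored.

```python
def elementwise_traffic_bytes(n_elements: int, n_ops: int,
                              bytes_per_elem: int = 4, fused: bool = False) -> int:
    """Simplified DRAM-traffic model for a chain of unary elementwise ops.

    Unfused: each op is a separate kernel that reads n_elements and writes n_elements.
    Fused:   one kernel reads the input once and writes the final output once;
             intermediates stay in registers.
    """
    kernels = 1 if fused else n_ops
    return kernels * 2 * n_elements * bytes_per_elem


# relu(x * 2 + 1) = 3 elementwise ops on one million float32 values
n = 1_000_000
unfused = elementwise_traffic_bytes(n, n_ops=3)
fused = elementwise_traffic_bytes(n, n_ops=3, fused=True)
print(unfused, fused, unfused / fused)  # 24000000 8000000 3.0
```

Under this model a memory-bound chain of three ops moves a third of the bytes when fused, so fusion, not faster arithmetic, is where most of the speedup would come from.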
### Applications
1. ML training loops (JAX, `torch.compile`).
2. Numerical Python (Numba `@njit`).
3. JavaScript engines (V8, JavaScriptCore).
4. Database query execution (e.g., PostgreSQL's LLVM-based JIT, Spark's whole-stage code generation).
## 💻 Patterns
### Pattern 1: JAX jit (2026 standard)
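```python
import jax
import jax.numpy as jnp

@jax.jit
def attention(q, k, v):
    # q, k, v: [batch, heads, seq_len, head_dim]
    scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bhkd->bhqd", weights, v)

kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q, k, v = (jax.random.normal(key, (2, 4, 128, 64)) for key in (kq, kk, kv))
out = attention(q, k, v)  # first call: trace + compile (slow)
out = attention(q, k, v)  # same shapes/dtypes: cached executable (fast)
```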
### Pattern 2: torch.compile (PyTorch 2.x)
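```python
import torch
import torch.nn.functional as F

# Assumes MyTransformer and dataloader (yielding (inputs, targets)) are defined elsewhere,
# and that the model returns logits of shape [batch, num_classes].
model = MyTransformer().cuda()
compiled = torch.compile(model, mode="reduce-overhead", fullgraph=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
for inputs, targets in dataloader:
    out = compiled(inputs.cuda())  # first batch: compile (slow); later batches: cached (fast)
    loss = F.cross_entropy(out, targets.cuda())  # backward() needs a scalar loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```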
### Pattern 3: Static argnums (avoid retrace)
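```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=(1,))
def topk(logits, k):
    # k is static: it becomes part of the cache key, so each distinct k compiles once
    return jax.lax.top_k(logits, k)

logits = jnp.arange(100.0)
topk(logits, 10)  # compiles a k=10 specialization
topk(logits, 10)  # cache hit
topk(logits, 20)  # new k: separate compilation
```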
### Pattern 4: Numba JIT (Python → LLVM)
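```python
import numpy as np
from numba import njit

@njit(cache=True, fastmath=True)  # cache=True stores the compiled code on disk between runs
def mandelbrot(c, max_iter=100):
    # Escape-time iteration count for one point c of the complex plane
    z = 0.0 + 0.0j
    for i in range(max_iter):
        z = z * z + c
        if z.real * z.real + z.imag * z.imag > 4.0:
            return i
    return max_iter

# First call compiles for complex128 input; later calls reuse the machine code.
counts = [mandelbrot(complex(x, 0.0)) for x in np.linspace(-2.0, 1.0, 7)]
```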
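### Pattern 5: Prewarm the persistent compilation cache
```python
import os
os.environ["JAX_COMPILATION_CACHE_DIR"] = "/var/cache/jax"  # set before `import jax`
import jax
jax.config.update("jax_persistent_cache_min_entry_size_bytes", 0)    # cache entries of any size
jax.config.update("jax_persistent_cache_min_compile_time_secs", 1.0)  # only compiles that took >= 1 s
# Run a prewarm script (with production shapes) on the first deployment so that
# later pods sharing this cache directory skip compilation on cold start.
```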
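
Conceptually, a persistent cache maps a key derived from the computation (and the compiler version) to a serialized compiled artifact on disk. The stdlib-only toy below illustrates that idea under stated assumptions: `DiskCodeCache` is a hypothetical name, and it caches Python code objects via `marshal`, not XLA executables.

```python
import hashlib
import marshal
import os
import sys
import tempfile


class DiskCodeCache:
    """Hypothetical toy persistent cache: (source, signature) -> marshaled code object on disk."""

    def __init__(self, cache_dir: str):
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)
        self.hits = 0
        self.misses = 0

    def _path(self, source: str, signature: str) -> str:
        # The marshal format is Python-version specific, so the version is part of the key.
        key_material = f"{sys.version_info[:2]}|{signature}|{source}".encode()
        return os.path.join(self.cache_dir, hashlib.sha256(key_material).hexdigest() + ".bin")

    def get_code(self, source: str, signature: str):
        path = self._path(source, signature)
        if os.path.exists(path):           # warm start: load instead of compiling
            self.hits += 1
            with open(path, "rb") as fh:
                return marshal.load(fh)
        self.misses += 1                   # cold start: compile, then persist
        code = compile(source, "<cached>", "eval")
        with open(path, "wb") as fh:
            marshal.dump(code, fh)
        return code


cache_dir = tempfile.mkdtemp()
prewarm = DiskCodeCache(cache_dir)        # e.g. the prewarm job
eval(prewarm.get_code("x * 2", "x:int"), {"x": 21})
worker = DiskCodeCache(cache_dir)         # e.g. a fresh process sharing the directory
result = eval(worker.get_code("x * 2", "x:int"), {"x": 21})
```

The second instance starts with empty counters, like a new pod, yet gets a hit because the artifact survived on disk.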
### Pattern 6: Recompilation detection
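```python
from collections import Counter
import jax
import jax.numpy as jnp

class CompileCounter:
    def __init__(self, threshold: int = 3):
        self.count, self.threshold = Counter(), threshold
    def trace(self, fn_name: str, sig: tuple):
        self.count[fn_name] += 1  # each trace precedes a (re)compilation
        if self.count[fn_name] > self.threshold:
            print(f"recompile thrash: {fn_name} traced {self.count[fn_name]} times, latest {sig}")

counter = CompileCounter()
@jax.jit
def step(x):
    counter.trace("step", (x.shape, str(x.dtype)))  # Python side effect: runs only while tracing
    return x * 2
for n in range(1, 6):
    step(jnp.ones(n))  # 5 new shapes -> 5 traces; warns on the 4th and 5th
# Built-in logs: jax.config.update("jax_log_compiles", True); TORCH_LOGS="recompiles" for torch.compile.
```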
### Pattern 7: Mojo JIT (2026)
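```mojo
fn matmul[M: Int, N: Int, K: Int](a: Tensor, b: Tensor) -> Tensor:
    # [M, N, K] are compile-time parameters: each shape combination gets its own
    # specialized instance with constant loop bounds, which can help the compiler vectorize.
    # Sketch only: the Tensor type and its API have changed across Mojo releases.
    var c = Tensor[DType.float32](M, N)
    for i in range(M):
        for j in range(N):
            var s: Float32 = 0
            for k in range(K):
                s += a[i, k] * b[k, j]
            c[i, j] = s
    return c
```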
## Decision criteria
| Situation | Approach |
|---|---|
| Tight numerical loop in Python | Numba `@njit`. |
| ML training | JAX `jit` or `torch.compile`. |
| Variable shapes | Avoid JIT, or use `torch.compile(dynamic=True)`. |
| One-shot script | Skip JIT; the compile overhead is not worth it. |
| Long-running server | JIT + persistent cache. |

**Default**: for ML, `torch.compile(mode="reduce-overhead")` or `jax.jit`; for tight numerical loops, Numba.
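
For the variable-shapes row, one common mitigation is shape bucketing: pad inputs up to a small set of sizes so the JIT only ever sees a few distinct shapes. The helpers below (`next_bucket`, `pad_to_bucket`) are hypothetical and use power-of-two buckets as an assumed policy; the true length is returned so a mask can ignore the padding.

```python
def next_bucket(n: int, min_bucket: int = 8) -> int:
    """Smallest power-of-two bucket >= n, with a floor of min_bucket."""
    bucket = min_bucket
    while bucket < n:
        bucket *= 2
    return bucket


def pad_to_bucket(tokens: list, pad_id: int = 0, min_bucket: int = 8):
    """Pad to the bucket size; also return the true length for masking."""
    target = next_bucket(len(tokens), min_bucket)
    return tokens + [pad_id] * (target - len(tokens)), len(tokens)


# 100 distinct sequence lengths collapse to 5 compiled shapes: 8, 16, 32, 64, 128.
shapes = sorted({next_bucket(n) for n in range(1, 101)})
padded, true_len = pad_to_bucket([5, 6, 7, 8, 9, 10, 11, 12, 13])  # 9 tokens -> 16 slots
```

The trade-off is wasted compute on padding (up to nearly 2× with power-of-two buckets) in exchange for a bounded number of compilations.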
## 🔗 Graph
- Parent: [[Performance-Optimization]]
- Variant: [[Tracing-JIT]]
- Applications: [[JAX]] · [[torch.compile]] · [[V8-Engine]]
- Adjacent: [[XLA]] · [[Triton]]
## 🤖 LLM usage
**When**: ML training/serving where the compile cost amortizes (>100 calls), tight numerical loops, long-running services.
**When not**: one-shot scripts, code whose shapes change constantly, debugging (use eager mode).
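
The amortization threshold can be sanity-checked with a simple break-even model: JIT wins once `N * t_eager >= C + N * t_jit`, i.e. `N >= C / (t_eager - t_jit)`, where `C` is the one-time compile cost. The function name and the timings below are hypothetical; plug in measured numbers.

```python
import math


def break_even_calls(compile_ms: float, eager_ms: float, jit_ms: float) -> float:
    """Calls needed before compile + jitted calls beat running eagerly.

    Solves N * eager_ms >= compile_ms + N * jit_ms for the smallest integer N.
    Returns math.inf when the compiled version is not faster per call.
    """
    saving = eager_ms - jit_ms
    if saving <= 0:
        return math.inf
    return math.ceil(compile_ms / saving)


# Hypothetical numbers: 2 s compile, 30 ms/call eager, 10 ms/call after JIT.
n = break_even_calls(2000, 30, 10)  # -> 100
```

With a 2 s compile and a 20 ms per-call saving, JIT pays off after 100 calls; a one-shot script never reaches that point.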
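## ❌ Anti-patterns
- **JIT in a hot Python loop with varying shapes**: every call retraces and recompiles, which is slower than running eagerly.
- **No persistent cache**: every deploy pays the full cold-start compile, which can take 30 s or more.
- **Debugging under JIT**: stack traces through compiled code are hard to use; disable JIT first (e.g., `jax.disable_jit()`) and debug in eager mode.
- **Premature JIT**: profile first; most code (the proverbial 80%) is not the bottleneck.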
## 🧪 Verification / Duplicates
- Verified: JAX docs (2026), PyTorch 2.x docs, "Engineering a Compiler" (Cooper & Torczon), V8 design docs.
- Confidence: A.
## 🕓 Changelog
| Date | Change |
|---|---|
| 2026-05-08 | Phase 1 |
| 2026-05-10 | Manual cleanup — full content with JAX/torch.compile/Numba/Mojo 2026 patterns |