---
id: wiki-2026-0508-just-in-time-jit
title: Just-In-Time (JIT)
category: 10_Wiki/Topics
status: verified
canonical_id: self
aliases: [JIT Compilation, Dynamic Compilation, Tracing JIT]
duplicate_of: none
source_trust_level: A
confidence_score: 0.95
verification_status: applied
tags: [compiler, optimization, runtime, performance, llvm]
raw_sources: []
last_reinforced: 2026-05-10
github_commit: pending
tech_stack:
  language: python
  framework: jax
---

# Just-In-Time (JIT)

## One-liner

> **"Compile on the first call, reuse on the hot path."**

JIT compilation translates source, bytecode, or IR into native code at runtime. Many JITs also use runtime profiling to concentrate optimization on hot regions. In 2026-era ML, JAX `jit`, PyTorch 2.x `torch.compile`, Mojo, and JuliaLang are mainstream.

## Core

### JIT mechanics

- **Trace**: capture the computational graph for a given input shape and dtype.
- **Specialize**: generate a kernel specialized to those fixed shapes.
- **Cache**: map (function, signature) → compiled artifact.
- **Recompile**: a shape change causes a cache miss and a recompile (avoid this in hot loops).

### vs AOT

- **AOT (ahead-of-time)**: rustc, gcc. Fast startup, but no runtime information to specialize on, so dynamic dispatch is harder to optimize.
- **JIT**: uses runtime information for better inlining and specialization, at the cost of compile time during startup.
- **Hybrid**: PyPy, V8, .NET. Start in an interpreter or a quick baseline tier, then recompile hot code with an optimizing JIT after N invocations.

### ML JIT specifics

- **Static shapes**: JAX `jit` specializes on the traced shapes; a new shape triggers a retrace.
- **XLA / Triton backends**: emit fused kernels. Fusion matters because many ML ops are bound by memory bandwidth rather than compute.
- **Compilation cache**: a persistent on-disk cache mitigates cold-start compile time.

### Applications

1. ML training loops (JAX, `torch.compile`).
2. Numerical Python (Numba `@njit`).
3. JavaScript engines (V8, JSC).
4. Database query execution (e.g. PostgreSQL's LLVM-based JIT for expression evaluation).
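The trace, specialize, cache, recompile cycle from "JIT mechanics" can be modeled in a few lines of plain Python. This is a toy sketch under stated assumptions, not how JAX, PyTorch, or Numba work internally: `ToyDotJit` is a hypothetical name, the input "signature" is reduced to the sequence length (standing in for shape/dtype), and "compiling" means generating fully unrolled Python source for that length and passing it to the built-in `compile()`.

```python
from typing import Callable, Dict, Sequence, Tuple


class ToyDotJit:
    """Toy specializing JIT for a dot product (hypothetical, illustrative only)."""

    def __init__(self) -> None:
        # Cache: signature -> compiled artifact (a generated Python function).
        self.cache: Dict[Tuple[int, ...], Callable] = {}
        self.compile_count = 0

    @staticmethod
    def _specialize(n: int) -> Callable:
        # "Specialize": bake the length n into fully unrolled source code.
        terms = " + ".join(f"a[{i}] * b[{i}]" for i in range(n)) or "0"
        src = f"def kernel(a, b):\n    return {terms}\n"
        namespace: dict = {}
        exec(compile(src, f"<dot_n{n}>", "exec"), namespace)
        return namespace["kernel"]

    def __call__(self, a: Sequence[float], b: Sequence[float]) -> float:
        if len(a) != len(b):
            raise ValueError("a and b must have the same length")
        sig = (len(a),)              # "Trace": capture the input signature (here: just the length)
        kernel = self.cache.get(sig)
        if kernel is None:           # Cache miss: specialize, compile, and remember the result
            self.compile_count += 1
            kernel = self._specialize(len(a))
            self.cache[sig] = kernel
        return kernel(a, b)          # Cache hit: reuse the compiled kernel


dot = ToyDotJit()
print(dot([1.0, 2.0], [3.0, 4.0]))            # first call for n=2 compiles -> 11.0
print(dot([5.0, 6.0], [7.0, 8.0]))            # same signature, cache hit -> 83.0
print(dot([1.0, 2.0, 3.0], [1.0, 1.0, 1.0]))  # new length, recompile -> 6.0
print(dot.compile_count)                      # 2
```

Changing the length plays the role of a shape change under `jax.jit`: it misses the cache and pays for a new specialization, which is why varying shapes inside a hot loop are expensive.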
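## 💻 Patterns

### Pattern 1: JAX jit (2026 standard)

```python
import jax
import jax.numpy as jnp

@jax.jit
def attention(q, k, v):
    scores = jnp.einsum("bhqd,bhkd->bhqk", q, k) / jnp.sqrt(q.shape[-1])
    weights = jax.nn.softmax(scores, axis=-1)
    return jnp.einsum("bhqk,bhkd->bhqd", weights, v)

# Example inputs: (batch, heads, seq_len, head_dim)
kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (2, 4, 16, 32))
k = jax.random.normal(kk, (2, 4, 16, 32))
v = jax.random.normal(kv, (2, 4, 16, 32))

out = attention(q, k, v)  # first call: trace + compile (slow)
out = attention(q, k, v)  # same shapes/dtypes: cached executable (fast)
```

### Pattern 2: torch.compile (PyTorch 2.x)

```python
import torch

# MyTransformer, dataloader and loss_fn are placeholders for your own model, data and loss.
model = MyTransformer().cuda()
compiled = torch.compile(model, mode="reduce-overhead", fullgraph=True)
optimizer = torch.optim.AdamW(model.parameters())

for inputs, targets in dataloader:
    out = compiled(inputs.cuda())        # first batch: trace + compile (slow); later batches: fast
    loss = loss_fn(out, targets.cuda())  # backward() needs a scalar loss, not the raw output
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```

### Pattern 3: Static arguments (`static_argnums`)

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=(1,))
def topk(logits, k):
    return jax.lax.top_k(logits, k)  # k must be a concrete Python int, so it is marked static

logits = jnp.arange(100.0)
topk(logits, 10)  # compiles a version specialized to k=10
topk(logits, 10)  # cache hit
topk(logits, 20)  # new static value: separate compilation
```

### Pattern 4: Numba JIT (Python → LLVM)

```python
from numba import njit

@njit(cache=True, fastmath=True)  # cache=True: keep compiled machine code on disk across runs
def mandelbrot(c, max_iter=100):
    z = 0.0 + 0.0j
    for i in range(max_iter):
        z = z * z + c
        if z.real * z.real + z.imag * z.imag > 4.0:
            return i
    return max_iter

mandelbrot(complex(-0.5, 0.5))  # first call compiles for a complex input; later calls run native code
```

### Pattern 5: Persistent compilation cache prewarm

```python
import os

# Set before importing jax so the persistent compilation cache is enabled at startup.
os.environ["JAX_COMPILATION_CACHE_DIR"] = "/var/cache/jax"

import jax

jax.config.update("jax_persistent_cache_min_entry_size_bytes", 0)
jax.config.update("jax_persistent_cache_min_compile_time_secs", 1.0)  # only cache compiles taking >= 1 s

# On first deployment, run a prewarm script that calls the jitted functions with the
# production input shapes; later pods that share this cache directory then load the
# compiled executables instead of recompiling on cold start.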
```

### Pattern 6: Recompilation detection

```python
from collections import defaultdict

class CompileCounter:
    """Flags functions compiled for many distinct signatures (retrace thrash)."""

    def __init__(self, threshold: int = 3):
        self.threshold = threshold
        self.signatures = defaultdict(set)  # fn_name -> set of signatures seen so far

    def trace(self, fn_name: str, sig: tuple) -> None:
        # Call this on every observed trace/compile event for fn_name.
        self.signatures[fn_name].add(sig)
        n = len(self.signatures[fn_name])
        if n > self.threshold:
            print(f"recompile thrash: {fn_name} compiled for {n} distinct signatures")

# Built-in alternatives:
#   JAX:     jax.config.update("jax_log_compiles", True)
#   PyTorch: TORCH_LOGS="recompiles" python train.py
```

### Pattern 7: Mojo JIT (2026)

```mojo
fn matmul[M: Int, N: Int, K: Int](a: Tensor, b: Tensor) -> Tensor:
    # M, N, K are compile-time parameters: each shape gets its own specialized
    # instantiation, which lets the compiler vectorize the loops with SIMD.
    var c = Tensor[DType.float32](M, N)
    for i in range(M):
        for j in range(N):
            var s: Float32 = 0
            for k in range(K):
                s += a[i, k] * b[k, j]
            c[i, j] = s
    return c
```

## Decision criteria

| Situation | Approach |
|---|---|
| Numerical Python tight loop | Numba `@njit`. |
| ML training | JAX `jit` or `torch.compile`. |
| Variable shapes | Avoid JIT, or use `torch.compile(dynamic=True)`. |
| One-shot script | Skip JIT; the compile overhead is never amortized. |
| Long-running server | JIT + persistent cache. |

**Default**: for ML, `torch.compile(mode="reduce-overhead")` or `jax.jit`; for tight numerical loops, Numba.

## 🔗 Graph

- Parent: [[Performance-Optimization]]
- Variant: [[Tracing-JIT]]
- Applications: [[JAX]] · [[torch.compile]] · [[V8 Engine]]
- Adjacent: [[XLA]] · [[Triton]]

## 🤖 LLM usage

**When**: ML training/serving where the compile cost amortizes (>100 calls), tight numerical loops, long-running services.

**When not**: one-shot scripts, code with constantly changing shapes, debugging (use eager mode).

## ❌ Anti-patterns

- **JIT in a hot Python loop with varying shapes**: every call retraces, which is slower than eager execution.
- **No persistent cache**: every deploy pays 30 s+ of compilation on cold start.
- **Debugging under JIT**: stack traces point into traced/compiled code and are hard to use; disable JIT and debug in eager mode first (e.g. `jax.disable_jit()` in JAX).
- **Premature JIT**: profile first; roughly 80% of code is not the bottleneck.

## 🧪 Verification / duplicates

- Verified: JAX docs (2026), PyTorch 2.x docs, "Engineering a Compiler" (Cooper & Torczon), V8 design docs.
- Confidence: A.
## 🕓 Changelog

| Date | Change |
|---|---|
| 2026-05-08 | Phase 1 |
| 2026-05-10 | Manual cleanup: full content with JAX/torch.compile/Numba/Mojo 2026 patterns |