Antigravity Agent f8b21af4be Wiki cleanup: error-doc removal, dedup merge, link normalization
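Large-scale cleanup of 10_Wiki/Topics:
- Removed 227 error-capture / unfinished stub documents
- Merged 43 cross-folder duplicate clusters (63 files → redirect)
- Link-name normalization: fixed broken links, pointed redirects straight at their targets, mapped concepts (~2,400 items)
- Created 6 new category MOCs
- Removed 10,058 unresolved related-keyword links from Graph sections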

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-20 23:52:15 +09:00


---
id: wiki-2026-0508-jit-compilation-in-ai-engines
title: JIT Compilation in AI Engines
category: 10_Wiki/Topics
status: verified
canonical_id: self
aliases: [JIT, torch.compile, XLA, JAX, TensorRT, ONNX Runtime, kernel fusion]
duplicate_of: none
source_trust_level: A
confidence_score: 0.94
verification_status: applied
tags: [compilation, jit, torch-compile, xla, jax, tensorrt, optimization]
raw_sources: []
last_reinforced: 2026-05-10
github_commit: pending
tech_stack:
  language: Python
  framework: PyTorch / JAX / TensorRT
---
# JIT Compilation in AI Engines
## In One Line
> **"Compile the runtime graph into fused, optimized kernels."** Modern stacks: PyTorch 2.x `torch.compile` (TorchDynamo + Inductor), JAX (XLA), TensorRT, ONNX Runtime. Typical gains: roughly 2-5x for inference and 1.5-2x for training, depending heavily on the model and hardware.
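What "fused" buys can be shown with a pure-Python sketch (an illustration only, not how Inductor or XLA generate code): the unfused version runs one loop per op and materializes an intermediate list after each, while the fused version computes `sin(x) * cos(x) + 1` in a single pass. On a GPU every avoided intermediate is a round trip through device memory, which is why fusion matters most for memory-bound chains of elementwise ops.

```python
import math

def unfused(xs):
    # One "kernel" per op; each materializes a full intermediate buffer.
    a = [math.sin(v) for v in xs]
    b = [math.cos(v) for v in xs]
    c = [p * q for p, q in zip(a, b)]
    return [v + 1.0 for v in c]

def fused(xs):
    # One "kernel": each element is read once and its result written once;
    # the intermediates live only in local expressions (registers on a GPU).
    return [math.sin(v) * math.cos(v) + 1.0 for v in xs]

xs = [i / 10 for i in range(100)]
assert unfused(xs) == fused(xs)  # same math, same evaluation order
```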
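## Core
### Frameworks
- **PyTorch torch.compile** (2.0+): TorchDynamo captures the graph, Inductor generates the kernels.
- **JAX**: traces Python functions and compiles them with XLA.
- **TensorRT** (NVIDIA): inference-only engine builder with FP16/INT8 quantization.
- **ONNX Runtime**: cross-framework runtime for exported ONNX models.
- **TVM**: open-source deep learning compiler.
- **Triton**: Python DSL for writing GPU kernels that are JIT-compiled.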
### Optimizations
- **Kernel fusion**.
- **Constant folding**.
- **Dead code elimination**.
- **Memory planning**.
- **Quantization**.
- **Graph rewrite**.
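Constant folding and dead code elimination can be sketched on a toy graph IR: a list of `(name, op, args)` nodes invented here for illustration (real compilers such as Inductor or XLA use much richer IRs). Folding evaluates every op whose inputs are all known at compile time; DCE then walks back from the outputs and drops every node nothing depends on. Note how folding `k` makes `two`, `three`, and `k` themselves dead.

```python
import operator

OPS = {'add': operator.add, 'mul': operator.mul}

def fold_constants(graph):
    """Replace ops whose inputs are all compile-time constants with their value."""
    consts, folded = {}, []
    for name, op, args in graph:
        vals = [consts.get(a, a) for a in args]  # substitute known constant values
        if op == 'const':
            consts[name] = vals[0]
        elif op in OPS and all(isinstance(v, (int, float)) for v in vals):
            consts[name] = OPS[op](*vals)        # evaluated once, at compile time
            op, vals = 'const', [consts[name]]
        folded.append((name, op, vals))
    return folded

def eliminate_dead_code(graph, outputs):
    """Keep only nodes that the outputs (transitively) depend on."""
    live, kept = set(outputs), []
    for name, op, args in reversed(graph):       # walk backwards from the outputs
        if name in live:
            kept.append((name, op, args))
            live.update(a for a in args if isinstance(a, str))
    return kept[::-1]

graph = [
    ('x', 'input', []),
    ('two', 'const', [2]),
    ('three', 'const', [3]),
    ('k', 'mul', ['two', 'three']),   # 2 * 3 is known at compile time
    ('y', 'mul', ['x', 'k']),
    ('unused', 'add', ['x', 'two']),  # never reaches an output
]
optimized = eliminate_dead_code(fold_constants(graph), outputs=['y'])
print(optimized)  # [('x', 'input', []), ('y', 'mul', ['x', 6])]
```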
### Applications
1. Inference acceleration.
2. Training speed.
3. Edge deployment.
4. Mobile.
## 💻 Patterns
### torch.compile (PyTorch 2.x)
```python
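import torch
model = MyModel().cuda()  # MyModel: placeholder for any nn.Module
x = torch.randn(8, 128, device='cuda')  # example input; the shape depends on the model
compiled = torch.compile(model, mode='reduce-overhead')
out = compiled(x)  # first call compiles (slow); later calls with the same shape reuse it (fast)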
```
### Modes
```python
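# Pick one mode per model.
# default: balanced compile time and speedup
compiled = torch.compile(model)
# reduce-overhead: uses CUDA graphs to cut Python/kernel-launch overhead (small models, small batches)
compiled = torch.compile(model, mode='reduce-overhead')
# max-autotune: longest compile; benchmarks kernel configurations for peak performance
compiled = torch.compile(model, mode='max-autotune')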
```
### JAX jit
```python
import jax
import jax.numpy as jnp

@jax.jit
def fn(x):
    return jnp.sin(x) * jnp.cos(x)

# first call: trace + compile via XLA; later calls with the same shape/dtype reuse the executable
y = fn(jnp.ones(1000))
```
### JAX vmap + jit
```python
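import jax
import jax.numpy as jnp

@jax.jit
@jax.vmap  # maps model_fn over the leading (batch) axis
def model_fn(x):
    return x ** 2 + 1

batch_x = jnp.arange(12.0).reshape(4, 3)  # example batch of 4 rows
batch_out = model_fn(batch_x)  # auto-batched, then compiled as one XLA computation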
```
### TensorRT (inference)
```python
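import torch
import torch_tensorrt
# model: placeholder nn.Module; cast to FP16 to match the FP16 input spec
trt_model = torch_tensorrt.compile(
    model.eval().half().cuda(),
    inputs=[torch_tensorrt.Input((1, 3, 224, 224), dtype=torch.float16)],
    enabled_precisions={torch.float16},  # let TensorRT pick FP16 kernels
)
out = trt_model(x.half().cuda())  # x: example input of shape (1, 3, 224, 224)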
```
### ONNX export + Runtime
```python
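import torch
import onnxruntime as ort
torch.onnx.export(model, x, 'model.onnx',
                  input_names=['input'], output_names=['output'],  # names used by dynamic_axes and sess.run
                  dynamic_axes={'input': {0: 'batch'}, 'output': {0: 'batch'}})
sess = ort.InferenceSession('model.onnx', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
outputs = sess.run(None, {'input': x.cpu().numpy()})  # list of numpy arrays, one per output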
```
### Triton kernel (Python JIT)
```python
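import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, N, BLOCK: tl.constexpr):
    pid = tl.program_id(0)                       # which block this program instance handles
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < N                           # guard the last, partial block
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:  # x, y: CUDA tensors
    out = torch.empty_like(x)
    n = out.numel()
    grid = (triton.cdiv(n, 1024),)               # one program per 1024-element block
    add_kernel[grid](x, y, out, n, BLOCK=1024)   # compiled on first launch
    return out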
```
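The `pid` / `offsets` / `mask` logic in `add_kernel` can be traced with a sequential pure-Python simulation. This is only an illustration: `add_program` and `launch` are hypothetical names, and on a GPU the program instances run in parallel. The mask matters because `N` is rarely a multiple of `BLOCK`, so the last tile would otherwise read and write past the end of the data.

```python
def cdiv(a, b):
    """Ceiling division, like triton.cdiv."""
    return (a + b - 1) // b

def add_program(x, y, out, n, block, pid):
    # What one program instance of add_kernel does: handle one BLOCK-sized tile.
    offsets = [pid * block + i for i in range(block)]
    mask = [o < n for o in offsets]  # False past the end of the data
    for o, m in zip(offsets, mask):
        if m:                        # masked-off lanes neither load nor store
            out[o] = x[o] + y[o]

def launch(x, y, block=4):
    n = len(x)
    out = [0] * n
    for pid in range(cdiv(n, block)):  # grid = (cdiv(n, BLOCK),); parallel on a GPU
        add_program(x, y, out, n, block, pid)
    return out

print(launch([1, 2, 3, 4, 5], [10, 20, 30, 40, 50], block=2))  # [11, 22, 33, 44, 55]
```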
### Compile + flash attention
```python
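import torch
import torch.nn.functional as F
class Block(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.qkv = torch.nn.Linear(dim, 3 * dim)
    def forward(self, x):  # x: (batch, seq, dim); one head for brevity
        q, k, v = (t.unsqueeze(1) for t in self.qkv(x).chunk(3, dim=-1))  # (batch, 1, seq, dim)
        # SDPA dispatches to a FlashAttention / memory-efficient kernel when inputs qualify
        return F.scaled_dot_product_attention(q, k, v, is_causal=True).squeeze(1)
compiled = torch.compile(Block(64).cuda().half())  # torch.compile fuses the surrounding ops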
```
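The FlashAttention kernels that SDPA can dispatch to avoid materializing the full attention-score matrix by streaming over key/value blocks with an "online" softmax, which rescales the partial results whenever the running maximum changes. A pure-Python sketch for a single query row, without masking, checked against the naive formula; the function names and block size are illustrative only, and the real kernel additionally tiles the queries and keeps tiles in on-chip SRAM.

```python
import math

def naive_attention(q, keys, values):
    scale = 1 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]  # full score row held in memory
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[d] for w, v in zip(weights, values)) / total for d in range(dim)]

def online_attention(q, keys, values, block=2):
    """Processes keys/values block by block, never holding the whole score row."""
    scale = 1 / math.sqrt(len(q))
    dim = len(values[0])
    m, l, acc = -math.inf, 0.0, [0.0] * dim  # running max, normalizer, weighted sum
    for start in range(0, len(keys), block):
        ks, vs = keys[start:start + block], values[start:start + block]
        s = [scale * sum(a * b for a, b in zip(q, k)) for k in ks]
        m_new = max(m, max(s))
        corr = math.exp(m - m_new)           # rescale earlier partial results
        p = [math.exp(x - m_new) for x in s]
        l = l * corr + sum(p)
        acc = [a * corr + sum(pi * v[d] for pi, v in zip(p, vs)) for d, a in enumerate(acc)]
        m = m_new
    return [a / l for a in acc]

q = [1.0, 0.5]
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.5], [0.3, 0.2]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]
print(naive_attention(q, keys, values))
print(online_attention(q, keys, values, block=2))  # same values up to rounding
```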
### Graph break detection
```python
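# Graph breaks split the model into several graphs with eager Python in between, which hurts perf.
import torch
import torch._dynamo
torch._logging.set_logs(graph_breaks=True)  # programmatic equivalent of TORCH_LOGS="graph_breaks"
out = compiled(x)  # graph breaks are logged while compiling
explanation = torch._dynamo.explain(model)(x)  # alternatively, summarize breaks up front
print(explanation.graph_break_count, explanation.break_reasons)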
```
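Why breaks hurt: each break ends one compiled graph and starts another, so fusion cannot cross it and the code in between runs in eager Python. A toy sketch, not Dynamo's actual logic, that splits a flat op list at untraceable ops; the op names are hypothetical labels, and which real operations break the graph depends on the PyTorch version.

```python
def capture(ops, traceable):
    """Split a sequence of op names into compiled subgraphs at untraceable ops.

    Each untraceable op (e.g. a print or a .item() call) runs eagerly and counts
    as one graph break.
    """
    graphs, current, breaks = [], [], 0
    for op in ops:
        if op in traceable:
            current.append(op)
        else:
            breaks += 1
            if current:
                graphs.append(current)
            current = []
    if current:
        graphs.append(current)
    return graphs, breaks

ops = ['linear', 'relu', 'print', 'linear', 'item', 'softmax']
graphs, breaks = capture(ops, traceable={'linear', 'relu', 'softmax'})
print(breaks, graphs)  # 2 [['linear', 'relu'], ['linear'], ['softmax']]
```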
### Recompilation issue (input shape change)
```python
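import torch
# Default (dynamic=None): the first call specializes on the input shape; a new shape triggers
# a recompile, after which recent PyTorch versions mark the changed dimension as dynamic.
# dynamic=True compiles a shape-generic graph from the start.
compiled = torch.compile(model, dynamic=True)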
```
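Recompilation can be pictured as a cache keyed by a guard on the input shape. A toy sketch, not Dynamo's real guard system: `ToyCompiler` and its key scheme are hypothetical, and a real dynamic-shape graph reasons about sizes symbolically instead of simply ignoring them. Inputs are represented by their shapes only.

```python
import math

class ToyCompiler:
    """Caches one 'compiled' artifact per guard key (hypothetical, for illustration).

    dynamic=False: the key is the exact shape, so every new batch size recompiles.
    dynamic=True: the batch dimension is treated as symbolic, so one artifact serves all sizes.
    """

    def __init__(self, fn, dynamic=False):
        self.fn, self.dynamic = fn, dynamic
        self.cache, self.compiles = {}, 0

    def __call__(self, shape):
        key = ('batch', *shape[1:]) if self.dynamic else tuple(shape)
        if key not in self.cache:       # guard miss -> (re)compile
            self.compiles += 1
            self.cache[key] = self.fn   # stand-in for generated code
        return self.cache[key](shape)

def numel(shape):
    return math.prod(shape)

static = ToyCompiler(numel, dynamic=False)
dyn = ToyCompiler(numel, dynamic=True)
for batch in (1, 2, 4, 8):
    static((batch, 128))
    dyn((batch, 128))
print(static.compiles, dyn.compiles)  # 4 1
```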
### Benchmark
```python
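import time
import torch

def benchmark(fn, x, n_warmup=10, n_iter=100):
    """Mean seconds per call. Warmup absorbs compilation and autotuning."""
    for _ in range(n_warmup):
        fn(x)
    torch.cuda.synchronize()  # finish warmup work before starting the clock
    t0 = time.perf_counter()
    for _ in range(n_iter):
        fn(x)
    torch.cuda.synchronize()  # CUDA kernels run asynchronously; wait before stopping the clock
    return (time.perf_counter() - t0) / n_iter

t_eager = benchmark(model, x)
t_compiled = benchmark(compiled, x)
print(f'Speedup: {t_eager / t_compiled:.2f}x')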
```
### CUDA Graph (extreme)
```python
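import torch
# Fixed-shape inference: capture once, then replay to minimize kernel-launch overhead.
# model and the input shape are placeholders.
static_input = torch.randn(8, 3, 224, 224, device='cuda')
s = torch.cuda.Stream()  # warm up on a side stream before capture
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s), torch.no_grad():
    for _ in range(3):
        model(static_input)
torch.cuda.current_stream().wait_stream(s)
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g), torch.no_grad():
    static_output = model(static_input)
def step(x):  # each iteration
    static_input.copy_(x)  # write new data into the captured input buffer
    g.replay()             # relaunch the captured kernels
    return static_output.clone()  # replay overwrites static_output, so copy it out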
```
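The `copy_` / `.clone()` pattern follows from how replay works: the captured kernels always read from and write to the same memory. A pure-Python analogy that models only this buffer contract, not the launch-overhead savings; `ToyGraph` is hypothetical.

```python
class ToyGraph:
    """Analogy of CUDA Graph capture/replay with fixed input/output buffers (no GPU)."""

    def __init__(self, ops, size):
        self.static_in = [0.0] * size
        self.static_out = [0.0] * size
        self.ops = ops  # recorded once, at "capture" time

    def replay(self):
        for i, v in enumerate(self.static_in):
            for op in self.ops:
                v = op(v)
            self.static_out[i] = v  # overwrites the same output buffer on every replay

g = ToyGraph([lambda v: v * 2, lambda v: v + 1], size=3)

def step(x):
    g.static_in[:] = x           # like static_input.copy_(x)
    g.replay()
    return list(g.static_out)    # like .clone(): snapshot before the next replay

a = step([1.0, 2.0, 3.0])
b = step([10.0, 20.0, 30.0])
alias = g.static_out             # no copy: this list changes on the next replay
step([0.0, 0.0, 0.0])
print(a, b, alias)  # [3.0, 5.0, 7.0] [21.0, 41.0, 61.0] [1.0, 1.0, 1.0]
```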
### Quantization (TRT)
```python
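import torch
import torch_tensorrt
# INT8 post-training quantization with calibration (TorchScript-frontend PTQ API)
calib_data = [...]  # placeholder: a torch.utils.data.DataLoader of representative inputs
calibrator = torch_tensorrt.ptq.DataLoaderCalibrator(
    calib_data,
    algo_type=torch_tensorrt.ptq.CalibrationAlgo.ENTROPY_CALIBRATION_2,
    device=torch.device('cuda:0'),
)
trt_model = torch_tensorrt.compile(
    model,
    inputs=[torch_tensorrt.Input((1, 3, 224, 224), dtype=torch.float32)],  # inputs stay FP32
    enabled_precisions={torch.float32, torch.int8},  # INT8 is a kernel precision, not the input dtype
    calibrator=calibrator,
)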
```
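What calibration produces is a scale that maps real values onto the INT8 range. The sketch below swaps in the simplest method, symmetric max-abs calibration, which is not what `ENTROPY_CALIBRATION_2` does (TensorRT's entropy calibrators choose the range with an entropy/KL-divergence criterion and may clip outliers on purpose). The function names are hypothetical.

```python
def calibrate_scale(calib_batches):
    """Symmetric per-tensor INT8 scale from the largest |x| seen in calibration data."""
    amax = max(abs(v) for batch in calib_batches for v in batch)
    return amax / 127.0 if amax > 0 else 1.0

def quantize(xs, scale):
    # Round to the nearest int8 step, clamping anything outside the calibrated range.
    return [max(-127, min(127, round(v / scale))) for v in xs]

def dequantize(qs, scale):
    return [q * scale for q in qs]

calib = [[0.5, -1.0, 0.25], [2.0, -0.75, 1.5]]  # representative activations
scale = calibrate_scale(calib)                   # amax = 2.0, so scale = 2.0 / 127
q = quantize([0.4, -2.0, 3.0], scale)
print(q)  # [25, -127, 127]  (3.0 lies outside the calibrated range and is clipped)
print(dequantize(q, scale))
```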
### TVM (cross-platform)
```python
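import torch
import tvm
from tvm import relay
# Relay's PyTorch frontend expects a TorchScript module, so trace the model first.
scripted = torch.jit.trace(model.eval(), x)
mod, params = relay.frontend.from_pytorch(scripted, [('input', tuple(x.shape))])
target = 'cuda'
with tvm.transform.PassContext(opt_level=3):  # most aggressive built-in optimization level
    lib = relay.build(mod, target=target, params=params)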
```
### Compile-time decision
```python
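def should_compile(expected_calls: int, static_shapes: bool) -> bool:
    """Rule of thumb: compile when the one-time compile cost is likely to be amortized."""
    if expected_calls > 100:  # many calls amortize the compile cost
        return True
    if static_shapes:  # stable shapes mean no recompiles after the first call
        return True
    return False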
```
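The `expected_calls > 100` cutoff is a fixed rule of thumb. A more direct check, assuming compile time and per-call latency have been measured (for example with the benchmark helper earlier in this section), is the break-even count `N = ceil(t_compile / (t_eager - t_compiled))`. The timings below are made up for illustration.

```python
import math

def break_even_calls(compile_ms, eager_ms, compiled_ms):
    """Calls needed before compiling pays for itself; None if compiled isn't faster."""
    saving_ms = eager_ms - compiled_ms  # time saved per call
    if saving_ms <= 0:
        return None
    return math.ceil(compile_ms / saving_ms)

# Hypothetical timings: 45 s compile, 20 ms eager vs 12 ms compiled per call
print(break_even_calls(compile_ms=45_000, eager_ms=20, compiled_ms=12))  # 5625
```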
## Decision Criteria
| Situation | Tool |
|---|---|
| Default training | torch.compile |
| JAX research | jax.jit |
| Production inference | TensorRT |
| Cross-platform | ONNX Runtime |
| Custom kernel | Triton |
| Edge | TVM / TFLite |
| Repeated inference | + CUDA Graph |
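**Default**: on PyTorch 2.x, use `torch.compile` together with SDPA (which dispatches to FlashAttention kernels when eligible). For production inference, TensorRT; for cross-platform deployment, ONNX Runtime.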
## 🔗 Graph
- Variants: [[torch.compile]] · [[TensorRT]] · [[Triton]]
- Applications: [[Flash Attention]] · [[GPU-Programming-with-CUDA]] · [[LLM_Optimization_and_Deployment_Strategies|Quantization]]
- Adjacent: [[Foundation-Models]] · [[Edge-AI-and-Computing]]
## 🤖 LLM Usage
**When**: production workloads, large models, inference acceleration.
**When not**: input shapes change a lot between calls; early prototyping.
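## ❌ Anti-patterns
- **Compiling on every call**: the compile cost is never amortized.
- **Ignoring graph breaks**: performance stays flat.
- **Recompiling on every shape change**: forgetting `dynamic=True`.
- **No warmup in benchmarks**: compile time leaks into the measurement, so the results mislead.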
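The first anti-pattern is easy to reproduce with a toy compiler whose cache lives inside each wrapper. `toy_compile` is hypothetical; real frameworks also keep global caches, so how much work they actually redo in this situation depends on the framework, and the toy shows the worst case. The fix is the same everywhere: wrap once, outside the loop.

```python
def toy_compile(fn, stats):
    """Hypothetical stand-in for torch.compile / jax.jit with a per-wrapper cache."""
    cache = {}

    def wrapper(x):
        key = type(x).__name__      # toy guard: recompile only for new input types
        if key not in cache:
            stats['compiles'] += 1  # the expensive step in a real compiler
            cache[key] = fn
        return cache[key](x)

    return wrapper

def square(x):
    return x * x

bad = {'compiles': 0}
for x in range(5):
    toy_compile(square, bad)(x)     # anti-pattern: new wrapper (and empty cache) per call

good = {'compiles': 0}
square_c = toy_compile(square, good)  # compile once, outside the loop
for x in range(5):
    square_c(x)

print(bad['compiles'], good['compiles'])  # 5 1
```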
## 🧪 Verification / Duplicates
- Verified against the PyTorch 2.0, JAX, TensorRT, and Triton documentation.
- Confidence: A.
## 🕓 Changelog
| Date | Change |
|---|---|
| 2026-05-08 | Phase 1 |
| 2026-05-10 | Manual cleanup: JIT overview plus torch.compile / JAX / TRT / ONNX / Triton / TVM code |