---
id: wiki-2026-0508-jit-compilation-in-ai-engines
title: JIT Compilation in AI Engines
category: 10_Wiki/Topics
status: verified
canonical_id: self
aliases: [JIT, torch.compile, XLA, JAX, TensorRT, ONNX Runtime, kernel fusion]
duplicate_of: none
source_trust_level: A
confidence_score: 0.94
verification_status: applied
tags: [compilation, jit, torch-compile, xla, jax, tensorrt, optimization]
raw_sources: []
last_reinforced: 2026-05-10
github_commit: pending
tech_stack:
  language: Python
  framework: PyTorch / JAX / TensorRT
---

# JIT Compilation in AI Engines

## One-line summary
> **"Compile the runtime graph into fused, optimized kernels."** Modern stacks: PyTorch 2.x `torch.compile` (TorchDynamo + Inductor), JAX (XLA), TensorRT, ONNX Runtime. Typical payoff: roughly 2-5x faster inference and 1.5-2x faster training, depending on model and hardware.
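
To make "fused kernel" concrete, here is a minimal pure-Python sketch. It assumes plain lists as stand-ins for GPU buffers, and the names `unfused` and `fused` are illustrative only. It shows the idea behind fusion (one pass, no intermediate buffers), not how any real compiler generates code.

```python
# Toy model of kernel fusion: plain Python lists stand in for GPU buffers.

def unfused(xs):
    """Three separate 'kernels', each a full pass that materializes a new buffer."""
    a = [x * 2.0 for x in xs]        # kernel 1: scale
    b = [v + 1.0 for v in a]         # kernel 2: shift
    return [max(v, 0.0) for v in b]  # kernel 3: ReLU


def fused(xs):
    """One 'kernel': the same math in a single pass with no intermediates."""
    return [max(x * 2.0 + 1.0, 0.0) for x in xs]


data = [-2.0, -0.5, 0.0, 1.5]
assert unfused(data) == fused(data)
print(fused(data))  # [0.0, 0.0, 1.0, 4.0]
```

On real hardware the gain comes from skipping the reads and writes of `a` and `b` and from launching one kernel instead of three.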

## Core concepts

### Frameworks
- **PyTorch `torch.compile`** (2.0+): TorchDynamo graph capture with the Inductor backend.
- **JAX**: XLA-based.
- **TensorRT** (NVIDIA): inference-only, with quantization support.
- **ONNX Runtime**: cross-framework.
- **TVM**: deep learning compiler.
- **Triton**: Python DSL for JIT-compiled GPU kernels.

### Optimizations
- **Kernel fusion**.
- **Constant folding**.
- **Dead code elimination**.
- **Memory planning**.
- **Quantization**.
- **Graph rewrite**.
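
As one illustration of the passes listed above, the sketch below implements constant folding as a small graph rewrite. It uses Python's own `ast` module as a stand-in for a compiler IR, which is an assumption made for illustration. `ConstantFolder` and `fold` are hypothetical names, and only `+`, `-`, and `*` are handled.

```python
import ast
import operator

# Binary operators this toy pass knows how to evaluate at compile time.
_FOLDABLE = {ast.Add: operator.add, ast.Sub: operator.sub, ast.Mult: operator.mul}


class ConstantFolder(ast.NodeTransformer):
    """Replace `const <op> const` subtrees with a single precomputed constant."""

    def visit_BinOp(self, node):
        self.generic_visit(node)  # fold children first (bottom-up)
        op = _FOLDABLE.get(type(node.op))
        if (op is not None
                and isinstance(node.left, ast.Constant)
                and isinstance(node.right, ast.Constant)):
            folded = ast.Constant(op(node.left.value, node.right.value))
            return ast.copy_location(folded, node)
        return node


def fold(expr: str) -> str:
    """Return `expr` with constant subexpressions folded (needs Python 3.9+)."""
    tree = ConstantFolder().visit(ast.parse(expr, mode="eval"))
    return ast.unparse(ast.fix_missing_locations(tree))


print(fold("x * (2 * 3 + 4)"))  # x * 10
```

The subtree `2 * 3 + 4` is evaluated once during the rewrite, so the runtime expression only multiplies `x` by a constant.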

### Applications
1. Inference acceleration.
2. Training speedup.
3. Edge deployment.
4. Mobile deployment.

## 💻 Patterns

### torch.compile (PyTorch 2.x)
```python
import torch

# MyModel and x are placeholders for your own nn.Module and a CUDA input tensor
model = MyModel().cuda()
compiled = torch.compile(model, mode='reduce-overhead')
out = compiled(x)  # first call compiles (slow); later calls reuse the compiled code
```

### Modes
```python
# Alternatives: pick one.

# Default: balanced compile time and runtime performance
model = torch.compile(model)

# Reduce per-call overhead via CUDA graphs (helps small models)
model = torch.compile(model, mode='reduce-overhead')

# Maximum performance: autotunes kernels at the cost of a longer compile
model = torch.compile(model, mode='max-autotune')
```

### JAX jit
```python
import jax
import jax.numpy as jnp

@jax.jit
def fn(x):
    return jnp.sin(x) * jnp.cos(x)

# First call traces and compiles; later calls with the same shape and dtype are fast
y = fn(jnp.ones(1000))
```
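
The "trace, then reuse" behaviour noted in the comment above can be sketched without JAX. The toy below is an assumption-laden illustration, not JAX's implementation. `Tracer` and `toy_jit` are hypothetical names, only `+` and `*` are supported, and the "compiled" artifact is just a Python expression.

```python
class Tracer:
    """Symbolic stand-in for an array: records operations instead of computing them."""

    def __init__(self, expr, log):
        self.expr, self.log = expr, log

    def _binop(self, symbol, other):
        rhs = other.expr if isinstance(other, Tracer) else repr(other)
        result = Tracer(f"({self.expr} {symbol} {rhs})", self.log)
        self.log.append(result.expr)  # one recorded graph node per operation
        return result

    def __add__(self, other):
        return self._binop("+", other)

    def __mul__(self, other):
        return self._binop("*", other)


def toy_jit(fn):
    """Trace `fn` once on first call, then evaluate the recorded expression."""
    state = {}

    def wrapper(x):
        if "code" not in state:  # first call: run the Python body symbolically
            log = []
            out = fn(Tracer("x", log))
            state["graph"] = log
            state["code"] = compile(out.expr, "<traced>", "eval")
        return eval(state["code"], {"x": x})  # later calls skip the Python body

    wrapper.state = state
    return wrapper


calls = []

@toy_jit
def f(x):
    calls.append("python body ran")  # side effect happens only while tracing
    return x * x + 1

print(f(3), f(4), len(calls))  # 10 17 1
```

The Python body runs once, which is also why side effects such as `print` inside a jitted function appear only during tracing.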

### JAX vmap + jit
```python
import jax
import jax.numpy as jnp

@jax.jit
@jax.vmap
def model_fn(x):
    return x ** 2 + 1

batch_x = jnp.arange(8.0)      # example batch; replace with real data
batch_out = model_fn(batch_x)  # vmap maps model_fn over the leading axis
```

### TensorRT (inference)
```python
import torch
import torch_tensorrt

# model: an eval-mode nn.Module on the GPU; x: a (1, 3, 224, 224) input tensor
trt_model = torch_tensorrt.compile(
    model,
    inputs=[torch_tensorrt.Input((1, 3, 224, 224), dtype=torch.float16)],
    enabled_precisions={torch.float16},
)
out = trt_model(x.half().cuda())
```

### ONNX export + Runtime
```python
import torch
import onnxruntime as ort

# input_names makes 'input' the name that dynamic_axes and sess.run refer to
torch.onnx.export(model, x, 'model.onnx', input_names=['input'],
                  dynamic_axes={'input': {0: 'batch'}})

sess = ort.InferenceSession('model.onnx', providers=['CUDAExecutionProvider'])
out = sess.run(None, {'input': x.cpu().numpy()})  # list with one array per model output
```
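
### Triton kernel (Python JIT)
```python
import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, N, BLOCK: tl.constexpr):
    pid = tl.program_id(0)                       # index of this program instance
    offsets = pid * BLOCK + tl.arange(0, BLOCK)  # elements handled by this instance
    mask = offsets < N                           # guard the ragged final block
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

def add(x, y, BLOCK=1024):
    # Host-side launcher: one program instance per BLOCK elements of the CUDA tensors
    out = torch.empty_like(x)
    N = x.numel()
    add_kernel[(triton.cdiv(N, BLOCK),)](x, y, out, N, BLOCK=BLOCK)
    return out
```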
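
The indexing in the kernel above (program id, block offsets, bounds mask) can be checked on the CPU. The following is a plain-Python emulation under the assumption that lists stand in for device memory. `emulate_add` is a hypothetical helper, not part of Triton.

```python
def emulate_add(x, y, block=4):
    """CPU emulation of a 1-D blocked elementwise add (lists stand in for GPU memory)."""
    n = len(x)
    out = [None] * n
    num_programs = (n + block - 1) // block           # ceil(n / block) program instances
    for pid in range(num_programs):                   # on a GPU these run in parallel
        offsets = [pid * block + i for i in range(block)]
        mask = [off < n for off in offsets]           # False past the end of the data
        for off, ok in zip(offsets, mask):
            if ok:                                    # masked lanes neither load nor store
                out[off] = x[off] + y[off]
    return out, num_programs


result, programs = emulate_add(list(range(10)), [10] * 10, block=4)
print(programs, result)  # 3 [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
```

With 10 elements and a block of 4, the third program instance has two masked-off lanes, which is the case the `mask` argument protects against.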

### Compile + flash attention
```python
import torch
import torch.nn.functional as F

# scaled_dot_product_attention (SDPA) dispatches to a FlashAttention kernel
# when the device, dtype, and shapes allow it
class Block(torch.nn.Module):
    def forward(self, q, k, v):
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)

compiled = torch.compile(Block())  # attention runs as a single fused op
```
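
### Graph break detection
```python
# Graph breaks split the model into several compiled regions with eager code
# in between, which erodes the speedup. Log them to find the cause.
import torch
import torch._dynamo

torch._dynamo.config.verbose = True         # more detailed Dynamo output
torch._logging.set_logs(graph_breaks=True)  # PyTorch 2.1+; same as TORCH_LOGS="graph_breaks"
out = compiled(x)  # break locations and reasons are logged on the compiling call
```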

### Recompilation issue (input shape change)
```python
# Inputs whose shapes vary: compile with dynamic=True
compiled = torch.compile(model, dynamic=True)
# Without it, the graph is specialized to the first batch shape and a new shape triggers recompilation
```
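
Why shape changes cause recompiles can be shown with a toy cache keyed on input shape. This is an illustrative model only and does not reflect PyTorch internals. `ToyCompiler` is a hypothetical class and the "compiled" function is an ordinary lambda.

```python
class ToyCompiler:
    """Toy model of a shape-specialized compile cache (not PyTorch internals)."""

    def __init__(self, dynamic=False):
        self.dynamic = dynamic
        self.cache = {}        # cache key -> "compiled" function
        self.compile_count = 0

    def _key(self, batch):
        # Static mode specializes on the exact length; dynamic mode uses one symbolic key.
        return "any-length" if self.dynamic else len(batch)

    def __call__(self, batch):
        key = self._key(batch)
        if key not in self.cache:          # cache miss -> pay the compile cost
            self.compile_count += 1
            self.cache[key] = lambda xs: [2 * x for x in xs]
        return self.cache[key](batch)


batches = [[1, 2], [1, 2, 3], [4, 5], [1, 2, 3, 4]]
static, dynamic = ToyCompiler(dynamic=False), ToyCompiler(dynamic=True)
for b in batches:
    assert static(b) == dynamic(b)
print(static.compile_count, dynamic.compile_count)  # 3 1
```

The static variant compiles once per distinct length (2, 3, 4), while the dynamic variant compiles once in total. The price in a real compiler is that shape-generic code can be somewhat less specialized.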

### Benchmark
```python
import time
import torch

def benchmark(fn, x, n_warmup=10, n_iter=100):
    # Warmup absorbs compilation and autotuning so they are not timed
    for _ in range(n_warmup):
        fn(x)
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(n_iter):
        fn(x)
        torch.cuda.synchronize()  # wait for the GPU so each iteration is fully timed
    return (time.perf_counter() - t0) / n_iter

t_eager = benchmark(model, x)
t_compiled = benchmark(compiled, x)
print(f'Speedup: {t_eager / t_compiled:.2f}x')
```
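
### CUDA Graph (extreme)
```python
import torch

# Fixed-shape inference: capture the kernel sequence once, then replay it
# to minimize per-kernel launch overhead.
# static_input: a preallocated CUDA tensor with the fixed input shape.
# Run a few warmup iterations of model(static_input) on a side stream before capture.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    out = model(static_input)

# Each iteration: refill the captured input buffer in place, then replay
def step(x):
    static_input.copy_(x)
    g.replay()
    return out.clone()  # out is overwritten by the next replay, so copy it
```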
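
The copy-in, replay, copy-out discipline in `step` follows from the fact that a captured graph always reads and writes the same buffers. The toy below models that in plain Python. `ToyGraph` and `double_op` are hypothetical names, and nothing here touches the real CUDA Graph API.

```python
class ToyGraph:
    """Toy capture-and-replay model (plain Python, not the CUDA Graph API)."""

    def __init__(self):
        self.ops = []

    def capture(self, op):
        self.ops.append(op)      # record the work once

    def replay(self):
        for op in self.ops:      # rerun the recorded work with no re-recording
            op()


static_in = [0.0, 0.0, 0.0]      # fixed input buffer the recorded ops read from
static_out = [0.0, 0.0, 0.0]     # fixed output buffer the recorded ops write to

def double_op():
    for i, v in enumerate(static_in):
        static_out[i] = 2.0 * v

graph = ToyGraph()
graph.capture(double_op)

def step(x):
    static_in[:] = x             # copy new data INTO the captured buffer
    graph.replay()
    return list(static_out)      # copy OUT before the next replay overwrites it

first = step([1.0, 2.0, 3.0])
second = step([4.0, 5.0, 6.0])
print(first, second)  # [2.0, 4.0, 6.0] [8.0, 10.0, 12.0]
```

Returning `static_out` itself instead of a copy would make `first` change after the second call, which mirrors why the pattern above ends with `out.clone()`.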
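
### Quantization (TRT)
```python
import torch
import torch_tensorrt

# INT8 post-training quantization with calibration
# (legacy torch_tensorrt.ptq API; newer Torch-TensorRT releases may not provide it)
calib_data = [...]  # should be a DataLoader yielding representative input batches
trt_model = torch_tensorrt.compile(
    model,
    # Network inputs stay floating point; INT8 is used inside the engine
    inputs=[torch_tensorrt.Input((1, 3, 224, 224), dtype=torch.float32)],
    enabled_precisions={torch.int8},
    calibrator=torch_tensorrt.ptq.DataLoaderCalibrator(calib_data),
)
```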
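
What calibration computes can be sketched in a few lines. The example assumes symmetric per-tensor quantization with max-abs calibration, which is only one of several schemes (TensorRT's calibrators can use other algorithms such as entropy-based ones). All function names are hypothetical.

```python
def calibrate_scale(samples):
    """Max-abs calibration: choose the scale so the largest value maps to 127."""
    max_abs = max(abs(v) for batch in samples for v in batch)
    return max_abs / 127.0 if max_abs > 0 else 1.0


def quantize(values, scale):
    """Round to the nearest INT8 step and clamp to [-127, 127]."""
    return [max(-127, min(127, round(v / scale))) for v in values]


def dequantize(q_values, scale):
    return [q * scale for q in q_values]


calibration_batches = [[0.5, -1.0, 0.25], [63.5, -0.1, 12.0]]
scale = calibrate_scale(calibration_batches)    # 63.5 / 127 = 0.5
q = quantize([1.2, -63.5, 0.0, 100.0], scale)   # 100.0 lies outside the calibrated range
print(q)                     # [2, -127, 0, 127]
print(dequantize(q, scale))  # [1.0, -63.5, 0.0, 63.5]
```

The value 100.0 saturates at 127 because it never appeared during calibration, which is why calibration data should be representative of production inputs.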

### TVM (cross-platform)
```python
import torch
import tvm
from tvm import relay

# from_pytorch expects a TorchScript module, so trace the model first
scripted = torch.jit.trace(model.eval(), x)
mod, params = relay.frontend.from_pytorch(scripted, [('input', x.shape)])
target = 'cuda'
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target, params=params)
```

### Compile-time decision
```python
def should_compile(model, expected_calls):
    if expected_calls > 100:  # enough calls to amortize the compile cost
        return True
    # is_static_shape: hypothetical user-defined flag, not a torch.nn.Module attribute
    if getattr(model, 'is_static_shape', False):
        return True
    return False
```
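
The fixed threshold of 100 calls above is a rule of thumb. If the compile time and both per-call times have been measured, the break-even point can be computed directly. The sketch below assumes all three costs are given in the same unit, and `break_even_calls` is a hypothetical helper.

```python
def break_even_calls(compile_cost, eager_cost, compiled_cost):
    """Smallest call count at which compiling beats staying in eager mode.

    Compiled total: compile_cost + n * compiled_cost.  Eager total: n * eager_cost.
    All three arguments must use the same time unit.
    Returns None when the compiled path is not faster per call.
    """
    saving_per_call = eager_cost - compiled_cost
    if saving_per_call <= 0:
        return None
    return int(compile_cost // saving_per_call) + 1


# Illustrative numbers only, in milliseconds: 30 s compile, 20 ms eager, 8 ms compiled.
print(break_even_calls(30_000, 20, 8))  # 2501
```

With these made-up numbers the totals tie at 2500 calls, so compilation only pays off from call 2501 onward.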

## Decision criteria
| Situation | Tool |
|---|---|
| Default training | torch.compile |
| JAX research | jax.jit |
| Production inference | TensorRT |
| Cross-platform | ONNX Runtime |
| Custom kernel | Triton |
| Edge | TVM / TFLite |
| Repeated inference | + CUDA Graph |

**Defaults**: PyTorch 2.x = `torch.compile` + SDPA (automatic FlashAttention). Production inference = TensorRT. Cross-platform = ONNX Runtime.

## 🔗 Graph
- Variants: [[torch.compile]] · [[TensorRT]] · [[Triton]]
- Applications: [[Flash Attention]] · [[GPU-Programming-with-CUDA]] · [[LLM_Optimization_and_Deployment_Strategies|Quantization]]
- Adjacent: [[Foundation-Models]] · [[Edge-AI-and-Computing]]

## 🤖 LLM usage
**When to use**: production workloads, large models, inference acceleration.
**When not to use**: input shapes that change frequently; prototyping.

## ❌ Anti-patterns
- **Compiling on every call**: the compile cost is never amortized.
- **Ignoring graph breaks**: performance stays flat.
- **Recompiling on every shape change**: forgetting `dynamic=True`.
- **No warmup in benchmarks**: timings include compilation and mislead.
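
The first anti-pattern is easy to introduce by accident when the compile call sits inside a request handler. The sketch below counts compile invocations with `fake_compile`, a hypothetical stand-in for an expensive JIT step. Real frameworks may cache part of the work internally, so treat this as a model of the cost structure rather than a measurement.

```python
compile_calls = 0

def fake_compile(fn):
    """Hypothetical stand-in for an expensive JIT compile step; counts invocations."""
    global compile_calls
    compile_calls += 1
    return fn

def model(x):
    return 2 * x + 1

# Anti-pattern: the compile step sits inside the request path.
def serve_bad(x):
    return fake_compile(model)(x)

# Preferred: compile once at startup, reuse the compiled callable per request.
compiled_model = fake_compile(model)

def serve_good(x):
    return compiled_model(x)

bad_before = compile_calls
bad_out = [serve_bad(x) for x in range(5)]
bad_compiles = compile_calls - bad_before

good_before = compile_calls
good_out = [serve_good(x) for x in range(5)]
good_compiles = compile_calls - good_before

print(bad_compiles, good_compiles)  # 5 0
```

Both handlers return the same results. Only the number of compile steps on the request path differs.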

## 🧪 Verification / duplicates
- Verified (PyTorch 2.0 docs, JAX docs, TensorRT, Triton).
- Trust level A.

## 🕓 Changelog
| Date | Change |
|---|---|
| 2026-05-08 | Phase 1 |
| 2026-05-10 | Manual cleanup: JIT + torch.compile / JAX / TRT / ONNX / Triton / TVM code |