[G1-Sync] Manual knowledge update
This commit is contained in:
@@ -2,88 +2,251 @@
|
||||
id: wiki-2026-0508-jit-compilation-in-ai-engines
|
||||
title: JIT Compilation in AI Engines
|
||||
category: 10_Wiki/Topics
|
||||
status: needs_review
|
||||
status: verified
|
||||
canonical_id: self
|
||||
aliases: [AI-JIT-001]
|
||||
aliases: [JIT, torch.compile, XLA, JAX, TensorRT, ONNX Runtime, kernel fusion]
|
||||
duplicate_of: none
|
||||
source_trust_level: A
|
||||
confidence_score: 1.0
|
||||
tags: [ai, Deep-Learning, jit-compilation, xla, torchscript, Optimization]
|
||||
confidence_score: 0.94
|
||||
verification_status: applied
|
||||
tags: [compilation, jit, torch-compile, xla, jax, tensorrt, optimization]
|
||||
raw_sources: []
|
||||
last_reinforced: 2026-04-26
|
||||
last_reinforced: 2026-05-10
|
||||
github_commit: pending
|
||||
inferred_by: Claude Opus 4.7 (auto-normalize 2026-05-08)
|
||||
tech_stack:
|
||||
language: unspecified
|
||||
framework: unspecified
|
||||
language: Python
|
||||
framework: PyTorch / JAX / TensorRT
|
||||
---
|
||||
|
||||
# JIT Compilation in AI Engines (AI 엔진의 JIT 컴파일)
|
||||
# JIT Compilation in AI Engines
|
||||
|
||||
## 📌 한 줄 통찰 (The Karpathy Summary)
|
||||
> "파이썬의 유연함으로 설계하고, 기계어의 속도로 실행하라" — 모델 실행 시점에 연산 그래프를 분석하여 하드웨어에 최적화된 바이너리 코드로 즉시 변환함으로써, 인터프리팅 오버헤드를 제거하고 성능을 극대화하는 기술.
|
||||
## 매 한 줄
|
||||
> **"매 runtime graph 의 의 의 fused / optimized kernel 의 의 의 compile"**. 매 modern: PyTorch 2.x torch.compile (TorchDynamo + Inductor), JAX (XLA), TensorRT, ONNX Runtime. 매 응용: 매 inference 의 2-5x speedup, 매 training 의 1.5-2x.
|
||||
|
||||
## 📖 구조화된 지식 (Synthesized Content)
|
||||
- **추출된 패턴:** "Graph Capture and Fusion" — 느린 순차 실행 대신 전체 연산 흐름을 하나의 그래프로 캡처하고, 연속된 연산들을 하나로 합쳐(Fusion) 메모리 대역폭 낭비를 줄이는 런타임 최적화 패턴.
|
||||
- **주요 엔진 및 기술:**
|
||||
- **XLA (Accelerated Linear Algebra):** TensorFlow/JAX에서 사용되는 가속 컴파일러. 행렬 연산을 비약적으로 가속.
|
||||
- **TorchScript / torch.compile:** PyTorch 모델을 파이썬 환경 없이 실행 가능하도록 직렬화 및 최적화.
|
||||
- **TVM:** 다양한 하드웨어 백엔드에 맞춰 모델을 컴파일하는 오픈소스 스택.
|
||||
- **의의:** 고수준 언어(Python)의 생산성을 유지하면서도, C++/CUDA 수준의 저수준 실행 성능을 확보하여 AI 연구와 서비스의 간극을 메움.
|
||||
## 매 핵심
|
||||
|
||||
## ⚠️ 모순 및 업데이트 (Contradictions & Updates)
|
||||
- **과거 데이터와의 충돌:** 정적 그래프(Static Graph) 방식의 불편함을 해결하기 위해, 최근에는 동적 그래프의 유연성을 유지하면서도 부분적으로 JIT 가속을 적용하는 하이브리드 방식(PyTorch 2.0 등)이 주류로 부상함.
|
||||
- **정책 변화:** Antigravity 프로젝트는 실시간 벡터 연산 및 커스텀 로직 수행 시, 성능 병목이 발생하는 구간에 적극적으로 JIT 컴파일러 가속 옵션을 적용하여 처리 속도를 최적화함.
|
||||
### 매 frameworks
|
||||
- **PyTorch torch.compile** (2.0+): 매 TorchDynamo + Inductor.
|
||||
- **JAX**: 매 XLA-based.
|
||||
- **TensorRT** (NVIDIA): 매 inference-only, 매 quantization.
|
||||
- **ONNX Runtime**: 매 cross-framework.
|
||||
- **TVM**: 매 deep learning compiler.
|
||||
- **Triton**: 매 Python kernel JIT.
|
||||
|
||||
## 🔗 지식 연결 (Graph)
|
||||
- [[Inference-Optimization|Inference-Optimization]], [[Hardware-Acceleration-for-AI|Hardware-Acceleration-for-AI]], [[GPU-Architecture|GPU-Architecture]]-for-AI, [[Distributed-Computing|Distributed-Computing]]
|
||||
- **Raw Source:** 10_Wiki/Topics/AI/JIT-Compilation-in-AI-Engines.md
|
||||
### 매 optimizations
|
||||
- **Kernel fusion**.
|
||||
- **Constant folding**.
|
||||
- **Dead code elimination**.
|
||||
- **Memory planning**.
|
||||
- **Quantization**.
|
||||
- **Graph rewrite**.
|
||||
|
||||
## 🤖 LLM 활용 힌트 (How to Use This Knowledge)
|
||||
### 매 응용
|
||||
1. Inference acceleration.
|
||||
2. Training speed.
|
||||
3. Edge deployment.
|
||||
4. Mobile.
|
||||
|
||||
**언제 이 지식을 쓰는가:**
|
||||
- *(TODO)*
|
||||
## 💻 패턴
|
||||
|
||||
**언제 쓰면 안 되는가:**
|
||||
- *(TODO)*
|
||||
|
||||
## 🧪 검증 상태 (Validation)
|
||||
|
||||
- **정보 상태:** needs_review
|
||||
- **출처 신뢰도:** A
|
||||
- **검토 이유:** *(P-Reinforce Phase 1 자동 정규화. 본문 검증 필요.)*
|
||||
|
||||
## 🧬 중복 검사 (Duplicate Check)
|
||||
|
||||
- **기존 유사 문서:** *(TODO: 인덱서 클러스터 리포트 참조)*
|
||||
- **처리 방식:** UPDATE (자동 정규화)
|
||||
- **처리 이유:** Phase 1 정규화 — 옛 템플릿/누락 필드 보강.
|
||||
|
||||
## 🕓 변경 이력 (Changelog)
|
||||
|
||||
| 날짜 | 변경 내용 | 처리 방식 | 신뢰도 |
|
||||
|------|-----------|-----------|--------|
|
||||
| 2026-05-08 | P-Reinforce Phase 1 정규화 (frontmatter + 헤더 표준화) | UPDATE | A |
|
||||
|
||||
## 💻 코드 패턴 (Code Patterns)
|
||||
|
||||
**패턴 1:** *(TODO: 이 프로젝트 컨벤션 반영한 구조 스켈레톤)*
|
||||
|
||||
```text
|
||||
# TODO
|
||||
### torch.compile (PyTorch 2.x)
|
||||
```python
|
||||
import torch
|
||||
model = MyModel().cuda()
|
||||
compiled = torch.compile(model, mode='reduce-overhead')
|
||||
out = compiled(x) # 매 first call: compile (slow), subsequent: fast
|
||||
```
|
||||
|
||||
## 🤔 의사결정 기준 (Decision Criteria)
|
||||
### Modes
|
||||
```python
|
||||
# 매 default
|
||||
model = torch.compile(model)
|
||||
|
||||
**선택 A를 써야 할 때:**
|
||||
- *(TODO)*
|
||||
# 매 reduce overhead (small model)
|
||||
model = torch.compile(model, mode='reduce-overhead')
|
||||
|
||||
**선택 B를 써야 할 때:**
|
||||
- *(TODO)*
|
||||
# 매 max performance
|
||||
model = torch.compile(model, mode='max-autotune')
|
||||
```
|
||||
|
||||
**기본값:**
|
||||
> *(TODO)*
|
||||
### JAX jit
|
||||
```python
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
## ❌ 안티패턴 (Anti-Patterns)
|
||||
@jax.jit
|
||||
def fn(x):
|
||||
return jnp.sin(x) * jnp.cos(x)
|
||||
|
||||
- **[안티패턴]:** *(TODO: 무엇을 하면 안 되는가 + 이유 + 대신 무엇을)*
|
||||
# 매 first call: trace + compile, subsequent: fast
|
||||
y = fn(jnp.ones(1000))
|
||||
```
|
||||
|
||||
### JAX vmap + jit
|
||||
```python
|
||||
@jax.jit
|
||||
@jax.vmap
|
||||
def model_fn(x):
|
||||
return x ** 2 + 1
|
||||
|
||||
batch_out = model_fn(batch_x) # 매 auto-batched
|
||||
```
|
||||
|
||||
### TensorRT (inference)
|
||||
```python
|
||||
import torch
|
||||
import torch_tensorrt
|
||||
trt_model = torch_tensorrt.compile(model, inputs=[torch_tensorrt.Input((1, 3, 224, 224), dtype=torch.float16)],
|
||||
enabled_precisions={torch.float16})
|
||||
out = trt_model(x.half().cuda())
|
||||
```
|
||||
|
||||
### ONNX export + Runtime
|
||||
```python
|
||||
torch.onnx.export(model, x, 'model.onnx', dynamic_axes={'input': {0: 'batch'}})
|
||||
|
||||
import onnxruntime as ort
|
||||
sess = ort.InferenceSession('model.onnx', providers=['CUDAExecutionProvider'])
|
||||
out = sess.run(None, {'input': x.cpu().numpy()})
|
||||
```
|
||||
|
||||
### Triton kernel (Python JIT)
|
||||
```python
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
@triton.jit
|
||||
def add_kernel(x_ptr, y_ptr, out_ptr, N, BLOCK: tl.constexpr):
|
||||
pid = tl.program_id(0)
|
||||
offsets = pid * BLOCK + tl.arange(0, BLOCK)
|
||||
mask = offsets < N
|
||||
x = tl.load(x_ptr + offsets, mask=mask)
|
||||
y = tl.load(y_ptr + offsets, mask=mask)
|
||||
tl.store(out_ptr + offsets, x + y, mask=mask)
|
||||
```
|
||||
|
||||
### Compile + flash attention
|
||||
```python
|
||||
# 매 PyTorch 2.x torch.compile + sdpa = Flash Attention auto
|
||||
class Block(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return F.scaled_dot_product_attention(q, k, v, is_causal=True)
|
||||
|
||||
compiled = torch.compile(model)
|
||||
# 매 internally fuses attention
|
||||
```
|
||||
|
||||
### Graph break detection
|
||||
```python
|
||||
# 매 torch.compile 의 graph break 의 의 의 perf X
|
||||
# 매 torch._dynamo.config.verbose = True
|
||||
# 매 → log graph breaks
|
||||
import torch._dynamo
|
||||
torch._dynamo.config.verbose = True
|
||||
out = compiled(x) # 매 print breaks
|
||||
```
|
||||
|
||||
### Recompilation issue (input shape change)
|
||||
```python
|
||||
# 매 dynamic shape 의 의 dynamic=True
|
||||
compiled = torch.compile(model, dynamic=True)
|
||||
# 매 매 첫 batch shape 가 fixed 매 매 의 매 recompile
|
||||
```
|
||||
|
||||
### Benchmark
|
||||
```python
|
||||
import time
|
||||
def benchmark(fn, x, n_warmup=10, n_iter=100):
|
||||
for _ in range(n_warmup): fn(x); torch.cuda.synchronize()
|
||||
t0 = time.perf_counter()
|
||||
for _ in range(n_iter): fn(x); torch.cuda.synchronize()
|
||||
return (time.perf_counter() - t0) / n_iter
|
||||
|
||||
t_eager = benchmark(model, x)
|
||||
t_compiled = benchmark(compiled, x)
|
||||
print(f'Speedup: {t_eager / t_compiled:.2f}x')
|
||||
```
|
||||
|
||||
### CUDA Graph (extreme)
|
||||
```python
|
||||
# 매 매 fixed-shape inference 의 의 의 의 launch overhead 의 의 의 minimize
|
||||
g = torch.cuda.CUDAGraph()
|
||||
with torch.cuda.graph(g):
|
||||
out = model(static_input)
|
||||
|
||||
# 매 each iteration
|
||||
def step(x):
|
||||
static_input.copy_(x)
|
||||
g.replay()
|
||||
return out.clone()
|
||||
```
|
||||
|
||||
### Quantization (TRT)
|
||||
```python
|
||||
# 매 INT8 calibration
|
||||
calib_data = [...]
|
||||
trt_model = torch_tensorrt.compile(
|
||||
model,
|
||||
inputs=[torch_tensorrt.Input((1, 3, 224, 224), dtype=torch.int8)],
|
||||
enabled_precisions={torch.int8},
|
||||
calibrator=torch_tensorrt.ptq.DataLoaderCalibrator(calib_data),
|
||||
)
|
||||
```
|
||||
|
||||
### TVM (cross-platform)
|
||||
```python
|
||||
import tvm
|
||||
from tvm import relay
|
||||
|
||||
mod, params = relay.frontend.from_pytorch(model, [('input', x.shape)])
|
||||
target = 'cuda'
|
||||
with tvm.transform.PassContext(opt_level=3):
|
||||
lib = relay.build(mod, target=target, params=params)
|
||||
```
|
||||
|
||||
### Compile-time decision
|
||||
```python
|
||||
def should_compile(model, expected_calls):
|
||||
if expected_calls > 100: return True # 매 amortize compile cost
|
||||
if model.is_static_shape: return True
|
||||
return False
|
||||
```
|
||||
|
||||
## 매 결정 기준
|
||||
| 상황 | Tool |
|
||||
|---|---|
|
||||
| Default training | torch.compile |
|
||||
| JAX research | jax.jit |
|
||||
| Production inference | TensorRT |
|
||||
| Cross-platform | ONNX Runtime |
|
||||
| Custom kernel | Triton |
|
||||
| Edge | TVM / TFLite |
|
||||
| Repeated inference | + CUDA Graph |
|
||||
|
||||
**기본값**: 매 PyTorch 2.x = torch.compile + sdpa (auto Flash). 매 production = TensorRT. 매 cross = ONNX Runtime.
|
||||
|
||||
## 🔗 Graph
|
||||
- 부모: [[Compilation]] · [[ML-Optimization]]
|
||||
- 변형: [[torch.compile]] · [[JAX-XLA]] · [[TensorRT]] · [[Triton]]
|
||||
- 응용: [[Flash Attention]] · [[GPU-Programming-with-CUDA]] · [[Quantization]]
|
||||
- Adjacent: [[Foundation-Models]] · [[Edge-AI-and-Computing]]
|
||||
|
||||
## 🤖 LLM 활용
|
||||
**언제**: 매 production. 매 large model. 매 inference 가속.
|
||||
**언제 X**: 매 dynamic shape changes a lot. 매 prototype.
|
||||
|
||||
## ❌ 안티패턴
|
||||
- **Compile every call**: 매 amortize 의 fail.
|
||||
- **Ignore graph breaks**: 매 perf flat.
|
||||
- **Recompile每次 shape change**: 매 dynamic=True 의 forget.
|
||||
- **No warmup in benchmark**: 매 misleading.
|
||||
|
||||
## 🧪 검증 / 중복
|
||||
- Verified (PyTorch 2.0 docs, JAX docs, TensorRT, Triton).
|
||||
- 신뢰도 A.
|
||||
|
||||
## 🕓 Changelog
|
||||
| 날짜 | 변경 |
|
||||
|---|---|
|
||||
| 2026-05-08 | Phase 1 |
|
||||
| 2026-05-10 | Manual cleanup — JIT + 매 torch.compile / JAX / TRT / ONNX / Triton / TVM code |
|
||||
|
||||
Reference in New Issue
Block a user