2nd/10_Wiki/Topics/AI_and_ML/JIT-Compilation-in-AI-Engines.md
Antigravity Agent f8b21af4be Wiki cleanup: error-doc removal, dedup merge, link normalization
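Large-scale cleanup of 10_Wiki/Topics:
- Removed 227 error-capture / unfinished stub documents
- Merged 43 cross-folder duplicate clusters (63 files → redirect)
- Link-name normalization: fixed broken links, pointed redirects directly at their targets, concept mapping (~2,400 items)
- Created 6 new category MOCs
- Removed 10,058 unresolved related-keyword links from Graph sections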

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-20 23:52:15 +09:00

id: wiki-2026-0508-jit-compilation-in-ai-engines
title: JIT Compilation in AI Engines
category: 10_Wiki/Topics
status: verified
canonical_id: self
aliases: JIT, torch.compile, XLA, JAX, TensorRT, ONNX Runtime, kernel fusion
duplicate_of: none
source_trust_level: A
confidence_score: 0.94
verification_status: applied
tags: compilation, jit, torch-compile, xla, jax, tensorrt, optimization
raw_sources:
last_reinforced: 2026-05-10
github_commit: pending
tech_stack: language: Python; framework: PyTorch / JAX / TensorRT

JIT Compilation in AI Engines

One-line summary

"Compile the runtime graph into fused, optimized kernels." Modern stacks: PyTorch 2.x torch.compile (TorchDynamo + Inductor), JAX (XLA), TensorRT, ONNX Runtime. Typical gains: roughly 2-5x speedup for inference and 1.5-2x for training, depending on the model and hardware.

Key points

Frameworks

  • PyTorch torch.compile (2.0+): TorchDynamo + Inductor.
  • JAX: XLA-based.
  • TensorRT (NVIDIA): inference-only, supports quantization.
  • ONNX Runtime: cross-framework.
  • TVM: deep learning compiler.
  • Triton: Python kernel JIT.

Optimizations

  • Kernel fusion.
  • Constant folding.
  • Dead code elimination.
  • Memory planning.
  • Quantization.
  • Graph rewrite.
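
To make the first item in the list above concrete, here is a minimal pure-Python sketch of what kernel fusion saves. It is a toy model under stated assumptions, not how Inductor or XLA are implemented: `run_unfused` and `run_fused` are hypothetical helpers, and Python lists stand in for device buffers. The unfused path makes one pass and allocates one intermediate buffer per op; the fused path composes the ops and makes a single pass.

```python
import math
from functools import reduce

# Toy assumption: a "graph" is just an ordered list of elementwise ops.
OPS = [math.sin, lambda v: v * 2.0, lambda v: v + 1.0]

def run_unfused(ops, xs):
    """Apply each op in its own pass, allocating one intermediate list per op."""
    buffers_allocated = 0
    cur = xs
    for op in ops:
        cur = [op(v) for v in cur]   # one full pass over memory per op
        buffers_allocated += 1
    return cur, buffers_allocated

def run_fused(ops, xs):
    """Compose all ops into one per-element function and make a single pass."""
    def fused(v):
        return reduce(lambda acc, op: op(acc), ops, v)
    return [fused(v) for v in xs], 1  # one output buffer, no intermediates

xs = [0.0, 0.5, 1.0]
unfused_out, unfused_bufs = run_unfused(OPS, xs)
fused_out, fused_bufs = run_fused(OPS, xs)
```

Both paths compute the same values; the difference is the number of memory passes and temporary buffers, which is where fused GPU kernels usually win on bandwidth-bound elementwise chains.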

Applications

  1. Inference acceleration.
  2. Training speed.
  3. Edge deployment.
  4. Mobile.

💻 Patterns

torch.compile (PyTorch 2.x)

import torch
model = MyModel().cuda()  # MyModel: placeholder for your own nn.Module
compiled = torch.compile(model, mode='reduce-overhead')
out = compiled(x)  # x: a CUDA input tensor; the first call compiles (slow), later calls are fast

Modes

# Default mode
model = torch.compile(model)

# Reduce per-call overhead (helps small models)
model = torch.compile(model, mode='reduce-overhead')

# Maximum performance (longer compile time due to autotuning)
model = torch.compile(model, mode='max-autotune')

JAX jit

import jax
import jax.numpy as jnp

@jax.jit
def fn(x):
    return jnp.sin(x) * jnp.cos(x)

# First call: trace + compile; subsequent calls with the same shape/dtype: fast
y = fn(jnp.ones(1000))
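
The "trace once, then replay" behaviour noted above can be sketched without JAX. The following is a toy tracer under loud assumptions: `Tracer` and `toy_jit` are hypothetical names, only `+` and `*` are supported, and the function's result is assumed to be the last recorded op. Real `jax.jit` additionally retraces for new shapes and dtypes, which this sketch ignores.

```python
class Tracer:
    """Stand-in for an abstract value: records ops instead of computing them."""
    def __init__(self, tape, idx):
        self.tape, self.idx = tape, idx

    def _record(self, op, other):
        # 'other' is either another Tracer (referenced by slot) or a constant
        arg = ("slot", other.idx) if isinstance(other, Tracer) else ("const", other)
        self.tape.append((op, self.idx, arg))
        return Tracer(self.tape, len(self.tape))  # slot 0 is the input

    def __add__(self, other):
        return self._record("add", other)

    def __mul__(self, other):
        return self._record("mul", other)

def toy_jit(fn):
    cache = {}
    def wrapper(x):
        if "tape" not in cache:
            tape = []
            fn(Tracer(tape, 0))        # run the Python body once, recording ops
            cache["tape"] = tape
        slots = [x]                    # replay the recorded ops on the real value
        for op, lhs, (kind, val) in cache["tape"]:
            rhs = slots[val] if kind == "slot" else val
            slots.append(slots[lhs] + rhs if op == "add" else slots[lhs] * rhs)
        return slots[-1]
    return wrapper

trace_log = []

@toy_jit
def f(x):
    trace_log.append("traced")         # Python side effect: runs only while tracing
    return x * x + 3

results = [f(2), f(5)]
```

The side effect fires once even though `f` is called twice, which mirrors why `print` statements inside a jitted function only appear during tracing.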

JAX vmap + jit

@jax.jit
@jax.vmap
def model_fn(x):
    return x ** 2 + 1

batch_out = model_fn(batch_x)  # batch_x: array with a leading batch axis; vmap maps over it automatically

TensorRT (inference)

import torch
import torch_tensorrt
trt_model = torch_tensorrt.compile(model, inputs=[torch_tensorrt.Input((1, 3, 224, 224), dtype=torch.float16)],
                                    enabled_precisions={torch.float16})
out = trt_model(x.half().cuda())

ONNX export + Runtime

torch.onnx.export(
    model, x, 'model.onnx',
    input_names=['input'], output_names=['output'],  # names used by dynamic_axes and sess.run
    dynamic_axes={'input': {0: 'batch'}},
)

import onnxruntime as ort
sess = ort.InferenceSession('model.onnx', providers=['CUDAExecutionProvider'])
out = sess.run(None, {'input': x.cpu().numpy()})

Triton kernel (Python JIT)

import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, N, BLOCK: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < N
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)
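
The indexing logic of the kernel above (program id, block offsets, mask) can be emulated on the CPU to see why the mask is needed. This is only a sketch: `add_kernel_emulated` and `launch` are hypothetical helpers, Python lists replace device pointers, and the sequential loop over `pid` stands in for program instances that a GPU would run in parallel.

```python
def add_kernel_emulated(x, y, out, n, block, pid):
    """One program instance: handles elements [pid*block, pid*block + block)."""
    offsets = [pid * block + i for i in range(block)]
    mask = [off < n for off in offsets]        # guards the ragged last block
    for off, ok in zip(offsets, mask):
        if ok:                                 # masked lanes neither load nor store
            out[off] = x[off] + y[off]

def launch(x, y, block=4):
    n = len(x)
    out = [0] * n
    grid = (n + block - 1) // block            # ceil-division, like triton.cdiv(n, block)
    for pid in range(grid):                    # sequential here, parallel on a GPU
        add_kernel_emulated(x, y, out, n, block, pid)
    return out, grid

x = list(range(10))
y = [10 * v for v in x]
out, grid = launch(x, y, block=4)
```

With 10 elements and a block size of 4, the last program instance covers offsets 8 to 11; without the mask it would read and write past the end of the buffers.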

Compile + flash attention

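# PyTorch 2.x: scaled_dot_product_attention picks a fused (FlashAttention-style)
# backend automatically when the inputs and hardware qualify
import torch
import torch.nn.functional as F

class Block(torch.nn.Module):
    def forward(self, q, k, v):
        # q, k, v: (batch, heads, seq_len, head_dim)
        return F.scaled_dot_product_attention(q, k, v, is_causal=True)

compiled = torch.compile(Block())
# torch.compile keeps the attention call fused and can fuse the ops around it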

Graph break detection

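# Graph breaks split the model into several compiled graphs and hurt performance.
# Setting torch._dynamo.config.verbose = True makes Dynamo log more detail,
# which helps locate the breaks.
import torch._dynamo
torch._dynamo.config.verbose = True
out = compiled(x)  # inspect the log output for graph breaks
# Alternatively, torch.compile(model, fullgraph=True) raises an error at the first graph break.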

Recompilation issue (input shape change)

# For inputs whose shape varies, pass dynamic=True
compiled = torch.compile(model, dynamic=True)
# Otherwise the first batch's shape is specialized as fixed, and a different shape triggers a recompile
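
Why shape specialization leads to recompiles can be illustrated with a toy compile cache. This is a deliberately simplified model: `ToyCompiler` is a hypothetical class, and real TorchDynamo guards cover far more than tensor shapes. In static mode the cache key is the exact shape; in dynamic mode it is only the rank.

```python
class ToyCompiler:
    """Caches one 'compiled artifact' per guard key and counts compilations."""
    def __init__(self, dynamic=False):
        self.dynamic = dynamic
        self.cache = {}
        self.compile_count = 0

    def _guard_key(self, shape):
        # Static mode specializes on the exact shape; dynamic mode only on the rank.
        return len(shape) if self.dynamic else tuple(shape)

    def __call__(self, shape):
        key = self._guard_key(shape)
        if key not in self.cache:
            self.compile_count += 1            # cache miss -> (re)compile
            self.cache[key] = f"kernel_for_{key}"
        return self.cache[key]

# Made-up sequence of batch shapes seen at runtime
shapes = [(8, 128), (8, 128), (16, 128), (32, 128), (16, 128)]

static_jit = ToyCompiler(dynamic=False)
dynamic_jit = ToyCompiler(dynamic=True)
for s in shapes:
    static_jit(s)
    dynamic_jit(s)
```

In this sketch the static compiler compiles once per distinct shape, while the dynamic one compiles a single shape-generic artifact; the trade-off in practice is that shape-generic code can be somewhat less optimized.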

Benchmark

import time
import torch

def benchmark(fn, x, n_warmup=10, n_iter=100):
    # Warm-up calls absorb compilation and autotuning before timing starts
    for _ in range(n_warmup):
        fn(x)
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    for _ in range(n_iter):
        fn(x)
        torch.cuda.synchronize()  # wait for GPU work so wall time is meaningful
    return (time.perf_counter() - t0) / n_iter

t_eager = benchmark(model, x)
t_compiled = benchmark(compiled, x)
print(f'Speedup: {t_eager / t_compiled:.2f}x')

CUDA Graph (extreme)

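# For fixed-shape inference, CUDA Graphs minimize kernel-launch overhead.
# static_input: a preallocated CUDA tensor with the fixed input shape.
# PyTorch's docs recommend a few warm-up runs on a side stream before capture.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    out = model(static_input)

# Each iteration: copy new data into the captured buffer, then replay
def step(x):
    static_input.copy_(x)
    g.replay()
    return out.clone()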

Quantization (TRT)

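# INT8 post-training quantization with calibration (Torch-TensorRT PTQ calibrator API)
calib_data = [...]  # should be a torch.utils.data.DataLoader over representative inputs
trt_model = torch_tensorrt.compile(
    model,
    # The network input stays floating point; INT8 applies to internal layers
    inputs=[torch_tensorrt.Input((1, 3, 224, 224), dtype=torch.float32)],
    enabled_precisions={torch.float32, torch.int8},
    calibrator=torch_tensorrt.ptq.DataLoaderCalibrator(calib_data),
)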
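
What calibration produces is, at its core, a scale factor per tensor. The sketch below shows symmetric per-tensor INT8 quantization with a simple max-abs calibration. Note the swap: TensorRT's calibrators use more elaborate algorithms (entropy-based, for example), so max-abs is only a simpler stand-in, and all function names and data here are illustrative assumptions.

```python
def calibrate_scale(samples):
    """Symmetric per-tensor scale from the largest magnitude seen in calibration."""
    max_abs = max(abs(v) for batch in samples for v in batch)
    return max_abs / 127.0

def quantize(values, scale):
    # Round to the nearest integer and clamp to the symmetric signed 8-bit range
    return [max(-127, min(127, round(v / scale))) for v in values]

def dequantize(q_values, scale):
    return [q * scale for q in q_values]

# Made-up calibration batches; the largest magnitude is 2.54, so scale is about 0.02
calib_batches = [[-1.0, 0.25, 0.5], [2.54, -0.75, 0.0]]
scale = calibrate_scale(calib_batches)

x = [0.5, -1.0, 2.54, 10.0]      # 10.0 lies outside the calibrated range
q = quantize(x, scale)
x_hat = dequantize(q, scale)
```

Values inside the calibrated range round-trip with an error of at most half a quantization step, while the out-of-range value 10.0 saturates at 127. This is why calibration data should be representative of real inputs.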

TVM (cross-platform)

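import torch
import tvm
from tvm import relay

# from_pytorch expects a TorchScript module, so trace the model first
scripted = torch.jit.trace(model.eval(), x)
mod, params = relay.frontend.from_pytorch(scripted, [('input', x.shape)])
target = 'cuda'
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target, params=params)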

Compile-time decision

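def should_compile(model, expected_calls):
    # Many calls amortize the one-time compile cost
    if expected_calls > 100:
        return True
    # is_static_shape: hypothetical attribute marking models with fixed input shapes
    if getattr(model, 'is_static_shape', False):
        return True
    return False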
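
The fixed threshold of 100 calls above is a rule of thumb. If rough timings are available, the break-even point can be estimated directly: compile cost divided by time saved per call. The helper name and the numbers below are illustrative assumptions, not measurements.

```python
import math

def break_even_calls(compile_ms, eager_ms, compiled_ms):
    """Smallest number of calls after which compiling pays off, or None if never."""
    saved_per_call = eager_ms - compiled_ms
    if saved_per_call <= 0:
        return None                  # the compiled path is not faster
    return math.ceil(compile_ms / saved_per_call)

# Illustrative numbers only: 30 s compile, 10 ms eager step, 4 ms compiled step
n = break_even_calls(compile_ms=30_000, eager_ms=10, compiled_ms=4)

# A compiled path that is slower per call never amortizes
never = break_even_calls(compile_ms=1_000, eager_ms=4, compiled_ms=5)
```

With these assumed numbers, compilation only pays for itself after several thousand calls, which is why short-lived scripts and prototypes often run faster in eager mode.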

Decision criteria

| Situation | Tool |
|---|---|
| Default training | torch.compile |
| JAX research | jax.jit |
| Production inference | TensorRT |
| Cross-platform | ONNX Runtime |
| Custom kernel | Triton |
| Edge | TVM / TFLite |
| Repeated inference | add CUDA Graph |

Defaults: PyTorch 2.x = torch.compile + sdpa (automatic Flash Attention dispatch). Production = TensorRT. Cross-platform = ONNX Runtime.

🔗 Graph

🤖 LLM usage

When to use: production, large models, inference acceleration. When not to use: input shapes that change frequently, prototyping.

Anti-patterns

  • Compiling on every call: the compile cost is never amortized.
  • Ignoring graph breaks: performance stays flat.
  • Recompiling on every shape change: forgetting dynamic=True.
  • No warmup in benchmarks: results are misleading.
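
The last anti-pattern is easy to quantify with simulated numbers. The sketch below uses made-up latencies (a 2 s compile on the first call, 5 ms steady state) to show how much one unexcluded compile call distorts a mean; `simulated_latencies_ms` is a hypothetical helper, not a measurement tool.

```python
from statistics import mean

def simulated_latencies_ms(n_calls, compile_ms=2000.0, steady_ms=5.0):
    """First call pays the compile cost; the rest run at steady-state speed."""
    return [compile_ms + steady_ms] + [steady_ms] * (n_calls - 1)

lat = simulated_latencies_ms(n_calls=101)

naive_mean = mean(lat)        # includes the compile call
warm_mean = mean(lat[1:])     # discards the warm-up call before averaging
```

Under these assumptions the naive mean comes out several times larger than the steady-state latency, so a compiled model can look slower than eager mode purely because of how the benchmark was written.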

🧪 Verification / Duplicates

  • Verified (PyTorch 2.0 docs, JAX docs, TensorRT, Triton).
  • Trust level: A.

🕓 Changelog

| Date | Change |
|---|---|
| 2026-05-08 | Phase 1 |
| 2026-05-10 | Manual cleanup: JIT + torch.compile / JAX / TRT / ONNX / Triton / TVM code |