---
id: wiki-2026-0508-differentiable-programming
title: Differentiable Programming
category: 10_Wiki/Topics
status: verified
canonical_id: self
aliases: [diff prog, software 2.0, Karpathy, JAX, autograd, end-to-end optimization, Gumbel-Softmax]
duplicate_of: none
source_trust_level: A
confidence_score: 0.88
verification_status: applied
tags: [differentiable-programming, software-2-0, jax, pytorch, autograd, gradient, karpathy, neural-symbolic]
raw_sources: []
last_reinforced: 2026-05-10
github_commit: pending
tech_stack:
  language: Python
  framework: JAX / PyTorch / TensorFlow
---

# Differentiable Programming

## One-line summary

> **"The program itself becomes learnable."** Andrej Karpathy's Software 2.0 (2017): hand-coded logic gives way to gradient-optimized weights. Modern practice combines JAX's functional style, differentiable physics, and differentiable rendering. LLMs are the ultimate expression of Software 2.0.

## Core concepts

### Software 1.0 vs 2.0 (Karpathy)

| Aspect | 1.0 | 2.0 |
|---|---|---|
| Code | Written by humans | Found by gradient descent |
| Search | Deterministic | Stochastic |
| Modification | Manual edits | Retraining |
| Examples | Algorithms, business rules | Neural networks |

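
The contrast in the table can be made concrete with a toy case. The sketch below is illustrative only: the temperature-conversion task, the function names, and the learning rate are assumptions chosen for the example, not taken from Karpathy's essay. The same rule is first written by hand, then recovered from input/output pairs by gradient descent.

```python
# Software 1.0: the rule is written by hand.
def celsius_to_fahrenheit(c):
    return 1.8 * c + 32.0


# Software 2.0: the same rule is recovered from examples by gradient descent.
def fit_linear(xs, ys, lr=0.1, steps=500):
    w, b = 0.0, 0.0
    n = len(xs)
    for _ in range(steps):
        # Gradients of the mean squared error with respect to w and b.
        grad_w = sum(2 * (w * x + b - y) * x for x, y in zip(xs, ys)) / n
        grad_b = sum(2 * (w * x + b - y) for x, y in zip(xs, ys)) / n
        w -= lr * grad_w
        b -= lr * grad_b
    return w, b


xs = [-2.0, -1.0, 0.0, 1.0, 2.0]
ys = [celsius_to_fahrenheit(x) for x in xs]
w, b = fit_linear(xs, ys)
print(w, b)  # close to the hand-written constants 1.8 and 32.0
```

Nothing in `fit_linear` knows the conversion formula. Changing the behavior means changing the data and retraining, which is the "Modification" row of the table.
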
### Modern paradigm

- **Auto-grad**: the chain rule is applied automatically.
- **End-to-end**: the loss gradient reaches every parameter.
- **Differentiable everything**: physics, rendering, planning.

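
To show what "the chain rule is applied automatically" means mechanically, here is a minimal reverse-mode sketch in plain Python. The `Value` class is a hypothetical teaching aid supporting only `+` and `*`. It is not how PyTorch or JAX are implemented internally, but it follows the same idea: record the graph on the way forward, then propagate gradients backward.

```python
class Value:
    """Scalar that records how it was computed so gradients can flow back."""

    def __init__(self, data, parents=(), local_grads=()):
        self.data = data
        self.grad = 0.0
        self._parents = parents          # inputs of the op that produced this value
        self._local_grads = local_grads  # d(self)/d(parent) for each parent

    def __add__(self, other):
        return Value(self.data + other.data, (self, other), (1.0, 1.0))

    def __mul__(self, other):
        return Value(self.data * other.data, (self, other), (other.data, self.data))

    def backward(self):
        # Topologically order the graph, then apply the chain rule in reverse.
        order, seen = [], set()

        def visit(v):
            if id(v) not in seen:
                seen.add(id(v))
                for p in v._parents:
                    visit(p)
                order.append(v)

        visit(self)
        self.grad = 1.0
        for v in reversed(order):
            for parent, local in zip(v._parents, v._local_grads):
                parent.grad += local * v.grad


x, y = Value(3.0), Value(4.0)
z = x * y + x * x      # z = xy + x^2
z.backward()
print(x.grad, y.grad)  # dz/dx = y + 2x = 10.0, dz/dy = x = 3.0
```

Because every operation carries its own local derivative, the loss gradient reaches every input of the graph without any hand-written derivative of `z`.
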
### Challenges

#### Non-differentiable operations

- **Discrete choice** (argmax).
- **Conditionals / branches**.
- **Sampling**.
- **Solutions**: Gumbel-Softmax, REINFORCE, straight-through estimators.

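
Sampling blocks gradients because the sample itself is not a smooth function of the parameter. The score-function idea behind REINFORCE sidesteps this by differentiating the log-probability instead. The sketch below uses a single Bernoulli variable; the cost function `f`, the value of `theta`, and the sample count are assumptions for illustration only.

```python
import random


def f(b):
    # Arbitrary downstream cost of the discrete choice b in {0, 1}.
    return 5.0 if b == 1 else 2.0


def score_function_grad(theta, n_samples=200_000, seed=0):
    """Monte Carlo estimate of d/dtheta E[f(b)] with b ~ Bernoulli(theta)."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        b = 1 if rng.random() < theta else 0
        # Score: d/dtheta log p(b; theta)
        score = b / theta - (1 - b) / (1 - theta)
        total += f(b) * score
    return total / n_samples


theta = 0.3
estimate = score_function_grad(theta)
# Closed form: E[f(b)] = theta * f(1) + (1 - theta) * f(0), so the gradient is f(1) - f(0).
exact = f(1) - f(0)
print(estimate, exact)
```

The estimate is unbiased but noisy, which is why practical REINFORCE implementations usually subtract a baseline or turn to lower-variance relaxations such as Gumbel-Softmax.
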
#### Numerical issues

- **Exploding / vanishing gradients**.
- **Mode collapse** (GANs).
- **Solutions**: BatchNorm, residual connections, gradient clipping.

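
Of the remedies listed, gradient clipping is the simplest to write down. The following is a minimal sketch of clipping by global L2 norm over a flat list of floats; the function name and signature are assumptions for the example. In PyTorch the corresponding utility is `torch.nn.utils.clip_grad_norm_`.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Rescale gradient values so their joint L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm <= max_norm or total_norm == 0.0:
        return list(grads), total_norm
    scale = max_norm / total_norm
    return [g * scale for g in grads], total_norm


clipped, norm_before = clip_by_global_norm([3.0, 4.0], max_norm=1.0)
print(clipped, norm_before)  # direction is kept, norm shrinks from 5.0 to about 1.0
```

Scaling all components by the same factor preserves the update direction, which is why norm clipping is generally preferred over clamping each component independently.
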
### Frameworks

#### PyTorch

- Dynamic computation graph.
- Imperative style.
- Most widely used.

#### TensorFlow

- Static and dynamic graphs (TF2).
- Strong production tooling.

#### JAX

- Purely functional.
- Composable `jit`, `vmap`, and `pmap` transformations.
- Often preferred for research.

#### Mojo

- Python-compatible and fast.
- Still emerging.

### Applications

1. **Neural networks**: the canonical case.
2. **Differentiable physics** (Brax, Isaac).
3. **Differentiable rendering** (PyTorch3D, Mitsuba).
4. **Differentiable planning** (DreamerV3).
5. **Hyperparameter optimization** (gradient-based).
6. **Architecture search** (DARTS).
7. **Symbolic regression** (PySR).
8. **Compiler optimization**.

### Modern directions

- **Foundation models**: broad knowledge stored in neural-network weights.
- **Programmable LLMs**: control flow driven by an LLM.
- **Hybrid**: symbolic plus neural.

## 💻 Patterns

### Auto-grad (PyTorch)

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x.pow(2).sum()   # y = sum(x_i^2)
y.backward()         # reverse-mode pass fills x.grad with dy/dx = 2x
print(x.grad)        # tensor([2., 4., 6.])
```

### JAX (functional)

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sum(x ** 2)

grad_f = jax.grad(f)   # returns a new function that computes df/dx
print(grad_f(jnp.array([1.0, 2.0, 3.0])))   # [2. 4. 6.]

# Transformations compose
jit_grad_f = jax.jit(grad_f)     # compiled gradient
vmap_grad_f = jax.vmap(grad_f)   # gradient mapped over a leading batch axis
```

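
A common sanity check for any autodiff result, such as the gradient `[2, 4, 6]` in the two patterns above, is to compare it with central finite differences. The helper below is a plain-Python sketch; `numerical_grad` and the step size `eps` are assumptions for the example, not part of PyTorch or JAX.

```python
def f(x):
    return sum(v ** 2 for v in x)


def numerical_grad(func, x, eps=1e-6):
    """Central-difference approximation of the gradient of func at x."""
    grad = []
    for i in range(len(x)):
        plus = list(x)
        minus = list(x)
        plus[i] += eps
        minus[i] -= eps
        grad.append((func(plus) - func(minus)) / (2 * eps))
    return grad


approx = numerical_grad(f, [1.0, 2.0, 3.0])
print(approx)  # approximately [2, 4, 6], up to floating-point error
```

Finite differences need two function evaluations per input dimension, so they only serve as a test oracle. Reverse-mode autodiff obtains the full gradient in a single backward pass.
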
### Gumbel-Softmax (differentiable categorical)

```python
import torch
import torch.nn.functional as F

def gumbel_softmax(logits, tau=1.0, hard=False):
    """Differentiable approximation to categorical sampling."""
    # Gumbel(0, 1) noise; the 1e-20 terms guard against log(0).
    g = -torch.log(-torch.log(torch.rand_like(logits) + 1e-20) + 1e-20)
    y_soft = F.softmax((logits + g) / tau, dim=-1)

    if hard:
        # Straight-through: one-hot in the forward pass, soft gradient in the backward pass.
        index = y_soft.argmax(-1, keepdim=True)
        y_hard = torch.zeros_like(y_soft).scatter_(-1, index, 1.0)
        y = y_hard - y_soft.detach() + y_soft
    else:
        y = y_soft
    return y

# PyTorch also provides torch.nn.functional.gumbel_softmax with the same tau / hard arguments.
```

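
Gumbel-Softmax relaxes the Gumbel-max trick: taking the argmax of `logits + Gumbel noise` yields an exact sample from `softmax(logits)`. The plain-Python sketch below checks this empirically. The logits, seed, and sample count are arbitrary assumptions for the demonstration.

```python
import math
import random


def softmax(logits):
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


def gumbel_max_sample(logits, rng):
    """Draw one category index via argmax(logits + Gumbel(0, 1) noise)."""
    noisy = []
    for l in logits:
        u = rng.random()
        while u == 0.0:  # avoid log(0)
            u = rng.random()
        noisy.append(l - math.log(-math.log(u)))
    return max(range(len(logits)), key=lambda i: noisy[i])


rng = random.Random(0)
logits = [1.0, 0.0, -1.0]
n = 100_000
counts = [0] * len(logits)
for _ in range(n):
    counts[gumbel_max_sample(logits, rng)] += 1
freqs = [c / n for c in counts]
probs = softmax(logits)
print(freqs)
print(probs)  # the empirical frequencies should track these probabilities
```

Replacing the hard argmax with a temperature-scaled softmax, as in `gumbel_softmax` above, is what makes the sample differentiable with respect to the logits.
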
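### REINFORCE (policy gradient)

```python
import torch

def reinforce_loss(log_probs, returns):
    """Score-function (REINFORCE) surrogate loss for a stochastic policy."""
    return -(log_probs * returns).sum()

def discounted_returns(rewards, gamma=0.99):
    """Return-to-go G_t = r_t + gamma * G_{t+1} for every step."""
    returns, g = [], 0.0
    for r in reversed(rewards):
        g = r + gamma * g
        returns.append(g)
    return torch.tensor(returns[::-1])

# One episode. `env`, `policy`, and `optimizer` are assumed to exist;
# `env` follows the classic Gym API (4-tuple from `step`).
log_probs, rewards = [], []
state = env.reset()
done = False
while not done:
    action_dist = policy(state)   # a torch.distributions object
    action = action_dist.sample()
    log_probs.append(action_dist.log_prob(action))
    state, reward, done, _ = env.step(action)
    rewards.append(reward)

# Update: weight each log-probability by its return-to-go, not the immediate reward.
optimizer.zero_grad()
loss = reinforce_loss(torch.stack(log_probs), discounted_returns(rewards))
loss.backward()
optimizer.step()
```
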
### Differentiable physics (Brax)

```python
import brax
from brax import envs
from brax.training.agents.ppo import train

env = envs.create('halfcheetah', batch_size=256)
# The whole physics simulator is written in JAX, so it is differentiable end to end.
# Gradients can flow through simulation steps to train a policy directly.
# Note: PPO, imported above, is a model-free trainer and does not itself use those gradients.
```

### Differentiable rendering

```python
# PyTorch3D
import torch
from pytorch3d.renderer import MeshRenderer, MeshRasterizer, ...

renderer = MeshRenderer(rasterizer=MeshRasterizer(...), shader=...)
mesh = load_mesh(...)   # hypothetical helper returning a pytorch3d.structures.Meshes

# Learn per-vertex offsets rather than mutating the mesh in place.
deform_verts = torch.zeros_like(mesh.verts_packed(), requires_grad=True)
optimizer = torch.optim.SGD([deform_verts], lr=lr)   # `lr` and `n_iters` assumed defined

target_image = load_target()   # hypothetical helper; same shape as the render
for _ in range(n_iters):
    optimizer.zero_grad()
    rendered = renderer(mesh.offset_verts(deform_verts))
    loss = (rendered - target_image).pow(2).mean()
    loss.backward()   # gradients flow through the renderer to the offsets
    optimizer.step()
```

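### Hyperparameter as parameter

```python
import torch

# Treat the learning rate as a quantity to optimize.
# `model`, `train_one_epoch`, and `validation_loss` are assumed to exist.
lr = torch.tensor(0.01, requires_grad=True)
# Caveat: `lr.item()` yields a plain float, so autograd cannot see through the optimizer.
inner_optimizer = torch.optim.SGD(model.parameters(), lr=lr.item())

# The outer loss is the validation outcome of inner training.
def outer_loss(lr_value):
    inner_optimizer.param_groups[0]['lr'] = lr_value
    train_one_epoch()
    return validation_loss()

# As written, outer_loss is a black-box function of lr (suitable for Bayesian optimization).
# A gradient-based search needs d(outer_loss)/d(lr), obtained by unrolling the
# inner updates differentiably or by implicit differentiation.
```
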
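
To make the hypergradient idea concrete, the sketch below unrolls a single inner SGD step on a one-parameter quadratic and differentiates the validation loss with respect to the learning rate by hand. The quadratic losses, the targets `a` and `b`, and the hyper-learning-rate are all assumptions chosen so the result can be checked in closed form. A real setup would obtain the same derivative from autodiff.

```python
def train_grad(w, a=1.0):
    # Gradient of the training loss (w - a)^2.
    return 2.0 * (w - a)


def val_loss(w, b=0.8):
    return (w - b) ** 2


def val_grad(w, b=0.8):
    return 2.0 * (w - b)


def hypergradient(lr, w0):
    """d val_loss(w1) / d lr for one inner step w1 = w0 - lr * train_grad(w0)."""
    w1 = w0 - lr * train_grad(w0)
    # Chain rule: dL_val/dlr = dL_val/dw1 * dw1/dlr, with dw1/dlr = -train_grad(w0).
    return val_grad(w1) * (-train_grad(w0)), w1


w0, lr, hyper_lr = 0.0, 0.01, 0.05
for _ in range(100):
    g, _ = hypergradient(lr, w0)
    lr -= hyper_lr * g   # gradient descent on the learning rate itself

_, w1 = hypergradient(lr, w0)
print(lr, val_loss(w1))  # lr approaches 0.4, where one inner step lands on the validation optimum
```

Here the optimum is known analytically: `w1 = 2 * lr` must equal `b = 0.8`, so `lr = 0.4`. With many inner steps the unrolled graph grows with training length, which is the usual motivation for implicit differentiation.
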
### NeuroSymbolic (LLM + symbolic)

```python
# `llm` and `wolfram` are placeholder clients, not real library objects.
def neurosymbolic_solve(problem):
    # The LLM produces a symbolic query.
    sym = llm.generate(f'Translate to Wolfram Alpha: {problem}')
    # The symbolic engine computes the answer (non-differentiable step).
    result = wolfram.eval(sym)
    # The LLM explains the result.
    return llm.generate(f'Explain {result} for {problem}')
```

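### Differentiable algorithms (sorting, etc.)

```python
import torch

def soft_sort(x, tau=1.0):
    """Differentiable relaxation of a descending sort (SoftSort-style)."""
    # Hard-sorted values serve as anchors; gradients still flow through them.
    sorted_x, _ = torch.sort(x, dim=-1, descending=True)
    # dist[..., i, j] = |sorted_x_i - x_j|
    dist = (sorted_x.unsqueeze(-1) - x.unsqueeze(-2)).abs()
    # Row i is a soft one-hot over the inputs, peaked at the i-th largest element.
    P = torch.softmax(-dist / tau, dim=-1)
    # Equivalent to P @ x; approaches the hard sort as tau -> 0.
    return (P * x.unsqueeze(-2)).sum(-1)
```
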
### JAX `vmap` (auto-batching)

```python
import jax

# `model(params, x)` is assumed to be a pure function defined elsewhere.
def loss(params, x, y):
    return ((model(params, x) - y) ** 2).mean()

# Gradient for a single example (with respect to params, the first argument)
single_grad = jax.grad(loss)

# Per-example gradients: share params, map over the leading axis of x and y
batch_grad = jax.vmap(single_grad, in_axes=(None, 0, 0))
gradients = batch_grad(params, batch_x, batch_y)
```

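### Compose JAX `jit + grad + vmap`

```python
import jax

# `model(params, x)` is assumed to be a pure function defined elsewhere.
def loss_fn(params, x, y):
    return ((model(params, x) - y) ** 2).mean()

# grad differentiates with respect to params; jit compiles the resulting function.
# The decorated name now returns gradients, so it is called grad_fn here.
grad_fn = jax.jit(jax.grad(loss_fn))

# Each training step
gradients = grad_fn(params, x, y)
```
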
## Decision criteria

| Situation | Approach |
|---|---|
| Standard NN | PyTorch / TF |
| Research / functional style | JAX |
| RL with differentiable physics | Brax / MuJoCo |
| Differentiable rendering | PyTorch3D / Mitsuba |
| Categorical variables | Gumbel-Softmax |
| Sampling-based | REINFORCE |
| Hybrid symbolic + neural | LLM + Wolfram |
| Hyperparameters | Bayesian optimization or implicit differentiation |

**Defaults**: PyTorch (general use) + JAX (research) + Gumbel-Softmax (discrete choices).

## 🔗 Graph

- Parents: [[Deep Learning]] · [[Optimization]]
- Variants: [[Software-2-0]] · [[Auto-grad]]
- Applications: [[JAX]] · [[NEAT]]
- Tricks: [[Gumbel-Softmax]]
- Adjacent: [[Bayesian-Optimization]] · [[Computational_Creativity|Computational-Creativity]] · [[Neural-Symbolic-Integration|Neuro-Symbolic-AI]] · [[Reinforcement-Learning]]

## 🤖 LLM usage

**When to use**: ML algorithm design, differentiable simulation, hybrid neuro-symbolic systems.

**When not to use**: highly discrete problems (use combinatorial methods instead), simple algorithms.

## ❌ Anti-patterns

- **Forcing gradients through non-differentiable code**: the wrong tool for the job.
- **Ignoring exploding gradients**: leads to NaNs.
- **No gradient clipping**: unstable training.
- **Making everything differentiable**: sometimes a symbolic approach is better.

## 🧪 Verification / duplicates

- Verified (Karpathy "Software 2.0", JAX docs, PyTorch docs, Brax / PyTorch3D papers).
- Trust level A.
- Related: [[Bayesian-Optimization]] · [[Computational_Creativity|Computational-Creativity]] · [[Reinforcement-Learning]] · [[Cross-Entropy Loss]] · [[Deep Learning]].

## 🕓 Changelog

| Date | Change |
|---|---|
| 2026-05-08 | Phase 1 |
| 2026-05-10 | Manual cleanup: Software 2.0 plus PyTorch / JAX / Gumbel / REINFORCE / Brax code |