---
id: wiki-2026-0508-joint-optimization
title: Joint Optimization
category: 10_Wiki/Topics
status: verified
canonical_id: self
aliases: [Multi-Objective Optimization, Co-Optimization, End-to-End Optimization]
duplicate_of: none
source_trust_level: A
confidence_score: 0.9
verification_status: applied
tags: [optimization, ML, multi-objective]
raw_sources: []
last_reinforced: 2026-05-10
github_commit: pending
tech_stack:
  language: python
  framework: pytorch-jax
---

# Joint Optimization

## One-line summary

> **"Optimize multiple objectives / variables at the same time."** Joint optimization can reach a globally better solution than separate or sequential optimization. The cost is higher complexity; the risk is conflicting gradients. It is central to modern DL (end-to-end training), RL (actor-critic), and chip design (DSE).

## Key points

### Why optimize jointly?

- **Coupling**: when variables interact strongly, solving for them separately is suboptimal.
- **Information sharing**: shared representations / gradients let the objectives benefit each other.
- **End-to-end**: errors do not accumulate across pipeline stages.

### Challenges

- **Conflicting gradients**: objectives push the shared parameters in opposite directions.
- **Scaling**: mismatched loss magnitudes → one loss dominates the update.
- **Local minima**: the joint landscape tends to be more rugged.
- **Compute**: optimizing N variables jointly → the search space grows exponentially.

### Applications

1. **Multi-task learning**: a shared encoder + multiple heads.
2. **Actor-critic RL**: policy + value trained jointly.
3. **HW/SW co-design**: chip floorplan + scheduler optimized jointly.
4. **Pareto front**: the cost vs latency frontier.

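
The conflicting-gradients challenge listed above can be checked directly: two task gradients conflict when their cosine similarity is negative. The following is a minimal sketch in plain Python, assuming the gradients are already flattened into lists of floats. The helper names (`cosine`, `grad_sq_dist`) and the toy quadratic losses are illustrative assumptions, not part of any library.

```python
import math

def cosine(g1, g2):
    """Cosine similarity between two flat gradient vectors (lists of floats)."""
    dot = sum(a * b for a, b in zip(g1, g2))
    n1 = math.sqrt(sum(a * a for a in g1))
    n2 = math.sqrt(sum(b * b for b in g2))
    if n1 == 0.0 or n2 == 0.0:
        return 0.0  # a zero gradient conflicts with nothing
    return dot / (n1 * n2)

def grad_sq_dist(theta, target):
    """Gradient of the toy loss ||theta - target||^2 with respect to theta."""
    return [2.0 * (t - c) for t, c in zip(theta, target)]

# Two hypothetical tasks pulling one shared parameter vector toward different targets.
theta = [0.0, 0.0]
g_a = grad_sq_dist(theta, target=[1.0, 0.0])
g_b = grad_sq_dist(theta, target=[-1.0, 0.5])

# Negative cosine: a step that helps one task hurts the other.
conflict = cosine(g_a, g_b) < 0
```

In a PyTorch training loop the flat vectors would typically come from calling `torch.autograd.grad` on each task loss with respect to the shared parameters and concatenating the results.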
## 💻 Patterns

### Weighted sum (simplest)

```python
import torch
import torch.nn.functional as F

def joint_loss(pred1, pred2, y1, y2, w=(0.5, 0.5)):
    # Task 1: classification, task 2: regression
    l1 = F.cross_entropy(pred1, y1)
    l2 = F.mse_loss(pred2, y2)
    # Fixed weights; tune them if one loss dominates
    return w[0] * l1 + w[1] * l2
```

### GradNorm (auto-balance)

```python
# Chen et al. 2018: dynamic loss weighting
import torch

class GradNorm:
    def __init__(self, n_tasks, alpha=1.5):
        # One learnable weight per task, trained with the loss returned by update()
        self.weights = torch.ones(n_tasks, requires_grad=True)
        self.alpha = alpha
        self.initial_losses = None  # L_i(0), recorded on the first call

    def update(self, losses, shared_params):
        """losses: list of scalar task losses; shared_params: list of shared tensors."""
        losses = torch.stack(list(losses))
        if self.initial_losses is None:
            self.initial_losses = losses.detach()
        # Gradient norm of each weighted loss w.r.t. the shared parameters
        norms = []
        for w, l in zip(self.weights, losses):
            grads = torch.autograd.grad(w * l, shared_params,
                                        retain_graph=True, create_graph=True)
            norms.append(torch.cat([g.flatten() for g in grads]).norm())
        norms = torch.stack(norms)
        # Relative inverse training rate: slower tasks get a larger target norm
        ratio = losses.detach() / self.initial_losses
        target = norms.mean() * (ratio / ratio.mean()) ** self.alpha
        return (norms - target.detach()).abs().sum()
```

### MGDA (Multiple Gradient Descent Algorithm)

```python
# MGDA as applied to multi-task learning by Sener & Koltun 2018:
# find task weights whose combined gradient is a common descent direction
import numpy as np
from scipy.optimize import minimize

def mgda_solver(grads):
    """grads: list of per-task gradient arrays. Returns the task weights alpha."""
    # Minimum-norm point in the convex hull of the task gradients:
    # min ||sum_i alpha_i g_i||^2  s.t.  alpha >= 0, sum(alpha) = 1
    G = np.stack([np.asarray(g).ravel() for g in grads])

    def obj(a):
        return np.linalg.norm(a @ G) ** 2

    a0 = np.ones(len(grads)) / len(grads)
    cons = [{"type": "eq", "fun": lambda a: a.sum() - 1}]
    bnds = [(0, 1)] * len(grads)
    res = minimize(obj, a0, method="SLSQP", constraints=cons, bounds=bnds)
    return res.x  # task weights; the combined gradient is res.x @ G
```

### Actor-critic joint update

```python
# A2C-style joint loss; PPO replaces the actor term with a clipped surrogate.
# policy.log_prob / policy.entropy are assumed interfaces of the policy object.
def actor_critic_loss(states, actions, advantages, returns, policy, value):
    log_p = policy.log_prob(states, actions)
    # Policy gradient term: raise log-probability of advantageous actions
    actor_loss = -(log_p * advantages).mean()
    # squeeze(-1) avoids silent (N, 1) vs (N,) broadcasting in the value error
    critic_loss = (value(states).squeeze(-1) - returns).pow(2).mean()
    # Entropy bonus encourages exploration
    entropy = policy.entropy(states).mean()
    return actor_loss + 0.5 * critic_loss - 0.01 * entropy
```

### Pareto frontier sampling

```python
# Extract the non-dominated set from sampled multi-objective solutions
def pareto_front(solutions):
    """solutions: list of (obj1, obj2) tuples (minimize both)."""
    front = []
    for s in solutions:
        # s is dominated if another point is no worse in both objectives
        # and differs from s (so it is strictly better in at least one)
        dominated = any(
            s2[0] <= s[0] and s2[1] <= s[1] and s2 != s
            for s2 in solutions
        )
        if not dominated:
            front.append(s)
    return front
```

## Decision criteria

| Situation | Strategy |
|---|---|
| Objectives aligned | Weighted sum (simple) |
| Objectives conflicting | MGDA / PCGrad |
| Loss magnitudes mismatched | GradNorm |
| Trade-off needs exploring | Pareto frontier sweep |
| RL actor + critic | Joint PPO/SAC |

**Default**: start with a weighted sum → introduce GradNorm once an imbalance shows up.

## 🔗 Graph

- Parent: [[Optimization]]
- Applications: [[Actor-Critic]]

## 🤖 LLM usage

**When**: designing multi-objective loss functions, diagnosing gradient conflict, explaining Pareto analysis.
**When not**: single-objective optimization (over-complication).

## ❌ Anti-patterns

- **Random weight tuning**: grid-searching loss weights without GradNorm → unstable.
- **Ignoring gradient conflict**: ignoring cosine(g1, g2) < 0 → destructive interference.
- **Premature joint training**: separate pretraining followed by joint fine-tuning is often better.

## 🧪 Verification / Duplicates

- Verified (Chen 2018 GradNorm; Sener & Koltun 2018 MGDA; Yu 2020 PCGrad; Schulman 2017 PPO).
- Trust level A.

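
PCGrad (Yu 2020) is cited here and recommended for conflicting objectives, so a rough illustration may help: each task gradient is projected onto the normal plane of any other task gradient it conflicts with, and the projected gradients are then summed. This is a minimal sketch in plain Python, assuming flat gradient lists. `dot` and `pcgrad` are illustrative names, not a library API, and a real training loop would operate on tensors.

```python
import random

def dot(u, v):
    """Dot product of two flat vectors (lists of floats)."""
    return sum(a * b for a, b in zip(u, v))

def pcgrad(grads, rng=None):
    """Combine per-task gradients with PCGrad-style projection.

    grads: list of flat gradient vectors, one per task.
    Returns the sum of the projected gradients.
    """
    rng = rng or random.Random(0)  # seeded here only for reproducibility
    projected = []
    for i, g in enumerate(grads):
        g_i = list(g)
        others = [j for j in range(len(grads)) if j != i]
        rng.shuffle(others)  # the other tasks are visited in random order
        for j in others:
            g_j = grads[j]
            d = dot(g_i, g_j)
            sq = dot(g_j, g_j)
            if d < 0 and sq > 0:
                # Remove the component of g_i that points against g_j.
                g_i = [a - (d / sq) * b for a, b in zip(g_i, g_j)]
        projected.append(g_i)
    # Element-wise sum over tasks
    return [sum(col) for col in zip(*projected)]

# Two conflicting toy gradients (negative dot product).
g1 = [1.0, 0.0]
g2 = [-3.0, 1.0]
combined = pcgrad([g1, g2])
```

For these two vectors the plain sum `[-2.0, 1.0]` has a negative dot product with `g1`, while the projected combination has a positive dot product with both task gradients.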
## 🕓 Changelog

| Date | Change |
|---|---|
| 2026-05-08 | Phase 1 |
| 2026-05-10 | Manual cleanup: multi-objective optimization patterns + Pareto |