---
id: wiki-2026-0508-generalization-in-ai
title: Generalization in AI
category: 10_Wiki/Topics
status: verified
canonical_id: self
aliases: [generalization, OOD, distribution shift, robustness, double descent, scaling laws]
duplicate_of: none
source_trust_level: A
confidence_score: 0.96
verification_status: applied
tags: [ml, generalization, ood, robustness, scaling, double-descent, foundation-model]
raw_sources: []
last_reinforced: 2026-05-10
github_commit: pending
tech_stack:
  language: Python
  applicable_to: [ML Theory, Foundation Models, Robustness]
---

# Generalization in AI

## One-line summary

> **"Performing well on unseen data."** Generalization is measured by the gap between training and test performance. Modern topics: the over-parameterization paradox, double descent (Belkin), grokking, OOD robustness, and emergent generalization in foundation models.

## Core concepts

### Traditional view

- **Overfitting**: model capacity exceeds the complexity of the task.
- **Underfitting**: model capacity falls short of the complexity of the task.
- **Sweet spot**: the bias-variance trade-off.

### Modern view (DL)

- **Double descent** (Belkin 2019): over-parameterized models can still generalize; test error falls again past the interpolation threshold.
- **Grokking** (Power 2022): generalization appears long after the model has overfit the training set.
- **Lottery ticket** (Frankle): dense networks contain sparse subnetworks that train to comparable accuracy.
- **Implicit regularization** (SGD).
- **Flat minima** → tend to generalize better.

### Scaling laws

- **Kaplan 2020**: power law (loss vs. parameters N, data D, compute C).
- **Chinchilla** (Hoffmann 2022): roughly D = 20·N tokens is compute-optimal.
- **Llama 3 / 4**: trend toward over-training, i.e. more tokens per parameter than the compute-optimal ratio.

### OOD robustness

- **Distribution shift**: covariate, label, concept.
- **Group robustness** (worst-case).
- **Invariant features** (causal).
- **Domain generalization**.

### Applications

1. **Production ML monitoring**.
2. **Self-driving safety**.
3. **Medical AI**.
4. **Foundation model evals**.
5. **Few-shot transfer**.

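
### Compute-optimal sizing (sketch)

The Chinchilla rule of thumb listed under scaling laws (D = 20·N) can be turned into quick arithmetic. The sketch below assumes the common approximation C ≈ 6·N·D for training FLOPs, which this note does not state, and treats 20 tokens per parameter as a rough heuristic rather than an exact constant. `chinchilla_optimal` is a hypothetical helper name.

```python
import math


def chinchilla_optimal(compute_flops, tokens_per_param=20.0):
    """Split a FLOP budget into (parameters N, tokens D).

    Assumptions (not from the note): C ~= 6 * N * D, and D = tokens_per_param * N.
    Substituting gives N = sqrt(C / (6 * tokens_per_param)).
    """
    n_params = math.sqrt(compute_flops / (6.0 * tokens_per_param))
    n_tokens = tokens_per_param * n_params
    return n_params, n_tokens


if __name__ == "__main__":
    n, d = chinchilla_optimal(5.88e23)
    print(f"params ~ {n:.3g}, tokens ~ {d:.3g}")
```

Passing a smaller `tokens_per_param` than a model was actually trained with gives a quick read on how far past the compute-optimal ratio ("over-trained") that model is.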
## 💻 Patterns

### Train / val / test split

```python
from sklearn.model_selection import train_test_split

# 70 / 15 / 15 split, stratified on the label at both stages
X_tr, X_temp, y_tr, y_temp = train_test_split(X, y, test_size=0.3, stratify=y)
X_val, X_te, y_val, y_te = train_test_split(X_temp, y_temp, test_size=0.5, stratify=y_temp)
```

### Detect overfit

```python
def overfit_check(train_loss, val_loss, threshold=0.1):
    """Flag overfitting when the relative train/val gap exceeds `threshold`."""
    gap = (val_loss - train_loss) / train_loss
    return gap > threshold
```

### Early stopping (val)

```python
class EarlyStop:
    """Stop once validation loss has not improved for `patience` epochs."""
    def __init__(self, patience=5):
        self.patience = patience
        self.best = float('inf')
        self.bad = 0

    def step(self, val_loss):
        if val_loss < self.best:
            self.best = val_loss
            self.bad = 0
            return False
        self.bad += 1
        return self.bad >= self.patience
```

### Double descent visualization

```python
def double_descent_curve(model_capacity_range, loss_fn, build_model,
                         X_train, y_train, X_val, y_val):
    """Validation loss vs. capacity: falls, peaks near the interpolation threshold, falls again."""
    losses = []
    for cap in model_capacity_range:
        # assumes build_model(cap) has a scikit-learn style fit() that returns self
        m = build_model(cap).fit(X_train, y_train)
        losses.append(loss_fn(m, X_val, y_val))
    return losses  # plot against capacity to see the second descent
```

### OOD detection (Mahalanobis)

```python
import numpy as np

def ood_score(test_features, train_features):
    """Mahalanobis distance of each test row from the training feature distribution."""
    mu = train_features.mean(0)
    cov_inv = np.linalg.pinv(np.cov(train_features.T))
    diff = test_features - mu
    return np.sqrt(np.einsum('bi,ij,bj->b', diff, cov_inv, diff))
```

### Distribution shift (PSI)

```python
import numpy as np

def population_stability_index(expected, actual, bins=10):
    e_hist, edges = np.histogram(expected, bins=bins)
    # clip so values outside the reference range land in the edge bins
    a_hist, _ = np.histogram(np.clip(actual, edges[0], edges[-1]), bins=edges)
    e_pct = e_hist / len(expected) + 1e-9
    a_pct = a_hist / len(actual) + 1e-9
    return ((a_pct - e_pct) * np.log(a_pct / e_pct)).sum()
# Rule of thumb: < 0.1 stable; > 0.25 significant shift
```

### Group robustness (Worst-Group)

```python
import numpy as np

def worst_group_acc(predictions, labels, groups):
    """Return (worst per-group accuracy, dict of all per-group accuracies)."""
    group_accs = {}
    for g in np.unique(groups):
        mask = groups == g
        group_accs[g] = (predictions[mask] == labels[mask]).mean()
    return min(group_accs.values()), group_accs
```

### Domain generalization (DRO)

```python
import numpy as np

def dro_loss(losses_per_group, eta=1.0):
    """Soft worst-group loss (log-mean-exp); eta -> inf approaches the max over groups."""
    losses = np.asarray(losses_per_group, dtype=float)
    m = losses.max()  # shift by the max for numerical stability
    # same minimizer as mean(exp(eta * loss)), but reported on the loss scale
    return m + np.log(np.mean(np.exp(eta * (losses - m)))) / eta
```

### Augmentation (improve generalization)

```python
import torchvision.transforms as T

augment = T.Compose([
    T.RandomHorizontalFlip(),
    T.RandomCrop(32, padding=4),
    T.ColorJitter(0.2, 0.2, 0.2),
    T.AutoAugment(),  # default policy is the ImageNet one
])
```

### Mixup (interpolation)

```python
import numpy as np
import torch

def mixup(x, y, alpha=0.4):
    lam = np.random.beta(alpha, alpha)  # mixing coefficient in [0, 1]
    idx = torch.randperm(x.size(0))
    x_mix = lam * x + (1 - lam) * x[idx]
    y_a, y_b = y, y[idx]
    return x_mix, y_a, y_b, lam
# Loss: lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)
```

### SAM (Sharpness-Aware Minimization)

```python
from torch.optim import Optimizer

class SAM(Optimizer):
    """Skeleton only: stores the base optimizer and the neighborhood radius rho."""
    def __init__(self, params, base_optim, rho=0.05):
        super().__init__(params, dict())  # pass params as a list, not an exhausted generator
        self.base = base_optim
        self.rho = rho
    # A full SAM step (not shown) perturbs weights by rho * g / ||g||, recomputes
    # gradients there, restores the weights, then calls self.base.step().
```

### Flat-minima detection

```python
import numpy as np
import torch

@torch.no_grad()
def flatness(model, loss_fn, X, y, eps=0.01, n_perturb=20):
    """Mean loss increase under Gaussian weight noise (simplified sharpness proxy)."""
    base = loss_fn(model(X), y).item()
    params = list(model.parameters())
    perturbed = []
    for _ in range(n_perturb):
        noise = [eps * torch.randn_like(p) for p in params]
        for p, n in zip(params, noise):
            p.add_(n)
        perturbed.append(loss_fn(model(X), y).item())
        for p, n in zip(params, noise):
            p.sub_(n)  # undo exactly the noise that was added
    return np.mean(perturbed) - base
```

### Scaling law extrapolation

```python
from scipy.optimize import curve_fit

def power_law(N, alpha, beta, eps):
    # alpha: irreducible loss, beta: scale, eps: exponent
    return alpha + beta / N ** eps

def fit_scaling(model_sizes, losses):
    """Return the fitted (alpha, beta, eps)."""
    return curve_fit(power_law, model_sizes, losses, p0=[1, 1, 0.5])[0]
```

### Robustness eval

```python
def robustness_eval(model, attacks, X_test, y_test):
    """Accuracy under each attack; `attacks` maps name -> fn(model, X, y) returning adversarial X."""
    results = {}
    for name, attack_fn in attacks.items():
        adv_X = attack_fn(model, X_test, y_test)
        results[name] = (model(adv_X).argmax(-1) == y_test).float().mean().item()
    return results
```

### Calibration (ECE)

```python
import numpy as np

def expected_calibration_error(confidences, correct, n_bins=10):
    """ECE over equal-width bins; `confidences` = max predicted prob, `correct` = 0/1 hits."""
    confidences = np.asarray(confidences, dtype=float)
    correct = np.asarray(correct, dtype=float)
    bin_edges = np.linspace(0, 1, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(bin_edges[:-1], bin_edges[1:]):
        mask = (confidences > lo) & (confidences <= hi)  # (lo, hi] keeps confidence 1.0
        if mask.sum() == 0:
            continue
        bin_acc = correct[mask].mean()
        bin_conf = confidences[mask].mean()
        ece += mask.mean() * abs(bin_acc - bin_conf)
    return ece
```

### Transfer learning eval

```python
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

def transfer_score(source_model, target_X, target_y):
    """Frozen features → linear probe, scored on held-out target data."""
    feats = source_model.encode(target_X)  # `encode` is assumed to return a 2-D feature array
    f_tr, f_te, y_tr, y_te = train_test_split(feats, target_y, test_size=0.3, stratify=target_y)
    probe = LogisticRegression(max_iter=1000).fit(f_tr, y_tr)
    return probe.score(f_te, y_te)  # scoring on the training features would leak
```

## Decision criteria

| Situation | Approach |
|---|---|
| Overfit (small data) | Augment + early stop |
| Underfit | More capacity |
| Distribution shift | Monitoring + retrain |
| OOD robustness | Augment + DRO |
| Few-shot | Foundation model + transfer |
| Production | All of the above + monitoring + calibration |

**Default**: augmentation + early stopping + flat minima (SAM/SWA) + OOD detection + PSI monitoring in production.

## 🔗 Graph

- Parent: [[Machine-Learning]]
- Variants: [[Double-Descent]]
- Applications: [[Foundation-Models]] · [[Domain-Adaptation]]
- Adjacent: [[Epistemic-Uncertainty]] · [[Concept-Drift]]

## 🤖 LLM usage

**When**: every ML deployment, monitoring, and robustness evaluation.
**When not**: train-only academic exercises.

## ❌ Anti-patterns

- **Test set leak**: inflated, meaningless scores.
- **No OOD eval**: failures surface only in production.
- **Always shrinking capacity**: in modern deep learning, more capacity often generalizes better.
- **No calibration**: confidence values mislead.
- **No drift monitor**: performance degrades silently.

## 🧪 Verification / duplicates

- Verified against Belkin 2019 (double descent), Power 2022 (grokking), Hoffmann 2022 (Chinchilla), and Vapnik (statistical learning theory).
- Trust level A.

## 🕓 Changelog

| Date | Change |
|---|---|
| 2026-04-20 | Auto |
| 2026-05-08 | Phase 1 |
| 2026-05-10 | Manual cleanup: bias-variance + double descent / OOD / DRO / SAM / scaling code |