---
id: wiki-2026-0508-focal-loss
title: Focal Loss
category: 10_Wiki/Topics
status: verified
canonical_id: self
aliases: [focal loss, class imbalance, RetinaNet, hard example mining, dense detection]
duplicate_of: none
source_trust_level: A
confidence_score: 0.97
verification_status: applied
tags: [deep-learning, loss-function, focal-loss, class-imbalance, object-detection, retinanet]
raw_sources: []
last_reinforced: 2026-05-10
github_commit: pending
tech_stack:
  language: Python
  framework: PyTorch / TensorFlow
---

# Focal Loss

## One-line summary

> **"Cross-entropy scaled by a (1 - p_t)^γ factor."** Lin 2017 (RetinaNet). Down-weights easy examples and focuses training on hard ones. Enables dense object detection. Modern uses: detection, imbalanced classification, instance segmentation.

## Core

### Formula

```
FL(p_t) = -α_t (1 - p_t)^γ log(p_t)
```

- p_t: predicted probability of the true class.
- γ: focusing parameter (typically 2).
- α: class weight (balance).

### Motivation

- Dense detection: thousands of negatives vs. tens of positives.
- With plain CE, the many easy negatives dominate the total loss.
- The (1 - p_t)^γ factor lowers the weight of well-classified examples.

### Applications

1. **Object detection** (RetinaNet, FCOS, ATSS).
2. **Instance segmentation**.
3. **Imbalanced classification**.
4. **Medical** (rare disease).
5. **Fraud / anomaly**.

### Alternatives

- **Class weights** (CE).
- **Online Hard Example Mining (OHEM)**.
- **Balanced sampling**.
- **Asymmetric loss** (multi-label).
- **Class-balanced loss** (Cui 2019).

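### Worked example (effect of γ)

The formula in this section can be checked numerically without any framework. The sketch below is a minimal, plain-Python illustration; the helper names `modulating_factor` and `focal_term` are hypothetical and exist only for this example. It evaluates the per-example loss for a few values of p_t and γ, assuming α_t = 1 unless stated otherwise.

```python
import math


def modulating_factor(p_t: float, gamma: float) -> float:
    """(1 - p_t)^gamma: the factor that scales cross-entropy in focal loss."""
    return (1.0 - p_t) ** gamma


def focal_term(p_t: float, gamma: float = 2.0, alpha_t: float = 1.0) -> float:
    """FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t) for a single example."""
    return -alpha_t * modulating_factor(p_t, gamma) * math.log(p_t)


if __name__ == "__main__":
    # Sweep gamma to see how quickly well-classified examples lose weight.
    for gamma in (0.0, 1.0, 2.0, 5.0):
        row = ", ".join(
            f"p_t={p_t}: {focal_term(p_t, gamma):.5f}" for p_t in (0.5, 0.9, 0.99)
        )
        print(f"gamma={gamma}: {row}")
```

With γ = 0 the term reduces to plain cross-entropy. At γ = 2, an example with p_t = 0.9 has its cross-entropy scaled by (1 - 0.9)^2 = 0.01, while an example with p_t = 0.5 is scaled by 0.25.
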
## 💻 Patterns

### Focal loss (binary)

```python
import torch
import torch.nn.functional as F

def focal_loss_binary(logits, targets, alpha=0.25, gamma=2.0, reduction='mean'):
    """logits, targets: same shape; targets are 0/1 floats."""
    p = torch.sigmoid(logits)
    p_t = p * targets + (1 - p) * (1 - targets)              # prob. of the true class
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)  # per-example class weight
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
    focal = alpha_t * (1 - p_t).pow(gamma) * ce
    return focal.sum() if reduction == 'sum' else focal.mean()
```

### Focal loss (multi-class)

```python
def focal_loss_multi(logits, targets, gamma=2.0, alpha=None):
    """logits: [B, C], targets: [B] (int64). alpha: optional tensor [C] of class weights."""
    log_p = F.log_softmax(logits, dim=-1)
    log_pt = log_p.gather(1, targets.unsqueeze(1)).squeeze(1)  # log-prob of the true class
    pt = log_pt.exp()
    focal = -((1 - pt) ** gamma) * log_pt
    if alpha is not None:
        focal = alpha[targets] * focal
    return focal.mean()
```

### As nn.Module

```python
class FocalLoss(torch.nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, targets):
        return focal_loss_binary(logits, targets, self.alpha, self.gamma, self.reduction)
```

### RetinaNet detection

```python
class RetinaNetLoss(torch.nn.Module):
    def __init__(self, n_classes, gamma=2.0, alpha=0.25):
        super().__init__()
        self.n_classes = n_classes  # stored for reference; not used by the loss itself
        # Sum reductions: both terms are normalized once by the foreground count below.
        self.classification = FocalLoss(alpha=alpha, gamma=gamma, reduction='sum')
        self.regression = torch.nn.SmoothL1Loss(reduction='sum')

    def forward(self, cls_logits, box_preds, cls_targets, box_targets, fg_mask):
        """cls_targets: one-hot floats, same shape as cls_logits. fg_mask: bool per anchor."""
        n_fg = fg_mask.sum().clamp(min=1)  # avoid division by zero when no anchor matches
        cls_loss = self.classification(cls_logits, cls_targets)               # all anchors
        box_loss = self.regression(box_preds[fg_mask], box_targets[fg_mask])  # foreground only
        return (cls_loss + box_loss) / n_fg
```

### Class-balanced focal (Cui 2019)

```python
def class_balanced_focal(logits, targets, samples_per_class, beta=0.999, gamma=2.0):
    """Weights each class by the inverse of its effective number of samples.
    samples_per_class: tensor [C] of per-class counts."""
    n = samples_per_class.float()
    weights = (1 - beta) / (1 - beta ** n)      # 1 / E_n, with E_n = (1 - beta^n) / (1 - beta)
    weights = weights / weights.sum() * len(n)  # normalize so the weights sum to C
    log_p = F.log_softmax(logits, dim=-1)
    log_pt = log_p.gather(1, targets.unsqueeze(1)).squeeze(1)
    pt = log_pt.exp()
    return (-weights[targets] * (1 - pt) ** gamma * log_pt).mean()
```

### Asymmetric focal (multi-label)

```python
def asymmetric_focal(logits, targets, gamma_pos=1, gamma_neg=4, clip=0.05):
    """Ben-Baruch 2020: asymmetric γ plus probability shifting for negatives."""
    p = torch.sigmoid(logits)
    p_neg = (p - clip).clamp(min=0)  # shifted probability: very easy negatives drop to 0
    pos_loss = targets * (1 - p) ** gamma_pos * torch.log(p.clamp(min=1e-8))
    neg_loss = (1 - targets) * p_neg ** gamma_neg * torch.log((1 - p_neg).clamp(min=1e-8))
    return -(pos_loss + neg_loss).mean()
```

### γ tuning

```python
# Typical γ values (starting points to tune on validation data)
GAMMA_VALUES = {
    'extreme_imbalance': 5.0,
    'object_detection': 2.0,      # RetinaNet
    'medical_rare_disease': 3.0,
    'mild_imbalance': 1.0,
    'no_imbalance': 0.0,          # equivalent to plain CE
}
```

### Focal vs CE comparison

```python
def compare_losses(p_easy=0.95, p_hard=0.5):
    """Compare how much an easy and a hard example contribute under CE and focal loss."""
    ce_easy = -torch.log(torch.tensor(p_easy))
    ce_hard = -torch.log(torch.tensor(p_hard))
    fl_easy = (1 - p_easy) ** 2 * ce_easy
    fl_hard = (1 - p_hard) ** 2 * ce_hard
    print(f'CE: easy {ce_easy:.3f}, hard {ce_hard:.3f}, ratio {ce_hard/ce_easy:.1f}x')
    print(f'Focal γ=2: easy {fl_easy:.4f}, hard {fl_hard:.3f}, ratio {fl_hard/fl_easy:.0f}x')
    # Under focal loss the hard example dominates by a far larger ratio than under CE.
```

### OHEM alternative

```python
def ohem_loss(logits, targets, ratio=0.25):
    """Keep only the hardest examples in each batch."""
    losses = F.cross_entropy(logits, targets, reduction='none')
    n_keep = max(1, int(len(losses) * ratio))  # always keep at least one example
    topk = losses.topk(n_keep)[0]
    return topk.mean()
```

### Combine: focal + dice (segmentation)

```python
def focal_dice(logits, targets, smooth=1.0):
    fl = focal_loss_binary(logits, targets)
    p = torch.sigmoid(logits)
    intersection = (p * targets).sum()
    # Smoothing in numerator and denominator keeps the loss at 0 for an empty, correct mask.
    dice = 1 - (2 * intersection + smooth) / (p.sum() + targets.sum() + smooth)
    return fl + dice
```

### Eval metric (PR-AUC for imbalanced)

```python
from sklearn.metrics import average_precision_score

def eval_imbalanced(probs, targets):
    return {
        'pr_auc': average_precision_score(targets, probs),
        # Not accuracy: it is meaningless for imbalanced data.
    }
```

## Decision criteria

| Situation | Loss |
|---|---|
| Dense object detection | Focal (γ=2) |
| Mild imbalance | Class-weighted CE |
| Extreme imbalance | Focal (γ=3-5) |
| Multi-label imbalanced | Asymmetric focal |
| Segmentation | Focal + Dice |
| Easy task | Plain CE |

**Default**: imbalance > 100:1 → focal with γ=2 and α=0.25. Multi-label → asymmetric. Evaluate with PR-AUC, not accuracy.

## 🔗 Graph

- Parent: [[Loss-Function]] · [[Object-Detection]]
- Applications: [[RetinaNet]]
- Adjacent: [[Cross-Entropy-Loss]]

## 🤖 LLM usage

**When**: detection, imbalanced data, medical.
**When not**: balanced data, regression.

## ❌ Anti-patterns

- **Focal for balanced data**: can underfit.
- **γ too high**: most examples are effectively ignored.
- **No α**: class balance is still off.
- **Accuracy metric**: misleading on imbalanced data.

## 🧪 Verification / duplicates

- Verified (Lin RetinaNet 2017, Cui 2019, Ben-Baruch ASL 2020).
- Trust level A.

## 🕓 Changelog

| Date | Change |
|---|---|
| 2026-04-26 | LOSS-002 auto |
| 2026-05-08 | Phase 1 |
| 2026-05-10 | Manual cleanup: focal binary / multi / asymmetric / class-balanced code |