---
id: wiki-2026-0508-focal-loss
title: Focal Loss
category: 10_Wiki/Topics
status: verified
canonical_id: self
aliases: [focal loss, class imbalance, RetinaNet, hard example mining, dense detection]
duplicate_of: none
source_trust_level: A
confidence_score: 0.97
verification_status: applied
tags: [deep-learning, loss-function, focal-loss, class-imbalance, object-detection, retinanet]
raw_sources: []
last_reinforced: 2026-05-10
github_commit: pending
tech_stack:
  language: Python
  framework: PyTorch / TensorFlow
---
# Focal Loss
## One-line summary
> **"Cross-entropy scaled by a (1-p_t)^γ modulating factor."** Lin et al. 2017 (RetinaNet). It down-weights easy examples and focuses training on hard examples, which enables dense object detection. Modern uses: detection, imbalanced classification, instance segmentation.
## Core
### Formula
```
FL(p_t) = -α_t (1 - p_t)^γ log(p_t)
```
- p_t: predicted probability of the true class.
- γ: focusing parameter (typically 2).
- α_t: class weight used for balancing (α for the positive class, 1 - α for the negative class).
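Written out for the binary case, the definitions read as follows. This is a sketch in the notation of Lin et al. 2017; the label y ∈ {0, 1} and the model probability p for y = 1 are symbols assumed here, since they do not appear in the formula above.

```latex
p_t = \begin{cases} p & \text{if } y = 1 \\ 1 - p & \text{otherwise} \end{cases}
\qquad
\alpha_t = \begin{cases} \alpha & \text{if } y = 1 \\ 1 - \alpha & \text{otherwise} \end{cases}

\mathrm{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \log(p_t),
\qquad
\gamma = 0 \;\Rightarrow\; \mathrm{FL}(p_t) = -\alpha_t \log(p_t)
```

Setting γ = 0 removes the modulating factor, so the loss falls back to α-weighted cross-entropy.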
### Motivation
- Dense detection: thousands of negatives versus tens of positives.
- Under plain CE, the many easy negatives dominate the total loss.
- The (1-p_t)^γ factor shrinks the weight of well-classified examples.
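The effect can be checked with a little arithmetic. The sketch below assumes a hypothetical batch of 10,000 easy negatives (p_t = 0.99) and 10 hard positives (p_t = 0.3); the counts and probabilities are illustrative, not taken from any dataset, and α is left out to isolate the effect of γ.

```python
import math

def ce(p_t):
    """Cross-entropy for one example, given the probability of its true class."""
    return -math.log(p_t)

def focal(p_t, gamma=2.0):
    """Focal loss without the alpha term: (1 - p_t)^gamma * CE."""
    return (1.0 - p_t) ** gamma * ce(p_t)

def easy_share(loss_fn, n_easy=10_000, p_easy=0.99, n_hard=10, p_hard=0.3):
    """Fraction of the summed loss that comes from the easy examples."""
    easy_total = n_easy * loss_fn(p_easy)
    hard_total = n_hard * loss_fn(p_hard)
    return easy_total / (easy_total + hard_total)

print(f"CE:    easy share = {easy_share(ce):.3f}")
print(f"Focal: easy share = {easy_share(focal):.4f}")
```

Under these assumptions the easy examples contribute roughly 0.89 of the total CE loss but well under 0.01 of the focal loss.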
### Applications
1. **Object detection** (RetinaNet, FCOS, ATSS).
2. **Instance segmentation**.
3. **Imbalanced classification**.
4. **Medical** (rare disease).
5. **Fraud / anomaly**.
### Alternatives
- **Class weights** (CE).
- **Online Hard Example Mining (OHEM)**.
- **Balanced sampling**.
- **Asymmetric loss** (multi-label).
- **Class-balanced loss** (Cui 2019).
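The simplest of these alternatives, class-weighted CE, can be sketched with the `weight` argument of `F.cross_entropy`. The helper name `class_weighted_ce` and the inverse-frequency weighting scheme are assumptions chosen for illustration; other weighting schemes are equally valid.

```python
import torch
import torch.nn.functional as F

def class_weighted_ce(logits, targets, class_counts):
    """Cross-entropy with inverse-frequency class weights.

    logits: [B, C], targets: [B] (int64), class_counts: per-class sample counts.
    Inverse frequency is an illustrative choice, not the only option.
    """
    counts = torch.as_tensor(class_counts, dtype=torch.float, device=logits.device)
    weights = counts.sum() / (len(counts) * counts)  # rare classes get weight > 1
    return F.cross_entropy(logits, targets, weight=weights)

torch.manual_seed(0)
logits = torch.randn(8, 3)
targets = torch.tensor([0, 0, 0, 0, 0, 0, 1, 2])
loss = class_weighted_ce(logits, targets, class_counts=[900, 90, 10])
```

Unlike focal loss, this reweights by class only, so an easy and a hard example of the same class still receive the same weight.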
## 💻 Patterns
### Focal loss (binary)
```python
import torch
import torch.nn.functional as F

def focal_loss_binary(logits, targets, alpha=0.25, gamma=2.0):
    """logits, targets: same shape; targets are float 0/1 labels."""
    p = torch.sigmoid(logits)
    p_t = p * targets + (1 - p) * (1 - targets)              # prob of the true class
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)  # per-example class weight
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
    focal = alpha_t * (1 - p_t).pow(gamma) * ce
    return focal.mean()
```
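A quick sanity check of the binary form, as a sketch: with γ = 0 and α = 0.5 the loss should be exactly half of plain BCE. The helper `focal_binary` repeats the computation of `focal_loss_binary` under a different name only so that the snippet runs on its own; the random data is hypothetical.

```python
import torch
import torch.nn.functional as F

def focal_binary(logits, targets, alpha=0.25, gamma=2.0):
    # Same computation as focal_loss_binary, repeated so this runs standalone.
    p = torch.sigmoid(logits)
    p_t = p * targets + (1 - p) * (1 - targets)
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
    return (alpha_t * (1 - p_t).pow(gamma) * ce).mean()

torch.manual_seed(0)
logits = torch.randn(1000)
targets = (torch.rand(1000) < 0.05).float()  # targets must be float, not int or bool

bce = F.binary_cross_entropy_with_logits(logits, targets)
# With gamma=0 and alpha=0.5 the focal loss is half of plain BCE.
assert torch.allclose(focal_binary(logits, targets, alpha=0.5, gamma=0.0), 0.5 * bce)
# With gamma>0 every term is scaled by a factor in [0, 1], so the loss cannot grow.
assert focal_binary(logits, targets, alpha=0.5, gamma=2.0) <= 0.5 * bce
```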
### Focal loss (multi-class)
```python
def focal_loss_multi(logits, targets, gamma=2.0, alpha=None):
    """logits: [B, C], targets: [B] (int64). alpha: optional [C] tensor of class weights."""
    log_p = F.log_softmax(logits, dim=-1)
    log_pt = log_p.gather(1, targets.unsqueeze(1)).squeeze(1)  # log-prob of the true class
    pt = log_pt.exp()
    focal = -((1 - pt) ** gamma) * log_pt
    if alpha is not None:
        focal = alpha[targets] * focal
    return focal.mean()
```
### As nn.Module
```python
class FocalLoss(torch.nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, logits, targets):
        return focal_loss_binary(logits, targets, self.alpha, self.gamma)
```
### RetinaNet detection
```python
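class RetinaNetLoss(torch.nn.Module):
    def __init__(self, n_classes, gamma=2.0, alpha=0.25):
        super().__init__()
        self.n_classes = n_classes
        self.classification = FocalLoss(alpha=alpha, gamma=gamma)
        # Sum here so the box loss is normalized by n_fg exactly once.
        self.regression = torch.nn.SmoothL1Loss(reduction='sum')

    def forward(self, cls_logits, box_preds, cls_targets, box_targets, fg_mask):
        # fg_mask: boolean mask over anchors marking foreground (positive) anchors.
        n_fg = fg_mask.sum().clamp(min=1)
        # FocalLoss returns a mean over all elements; rescale it to a sum and
        # normalize by the number of foreground anchors, as in RetinaNet.
        cls_loss = self.classification(cls_logits, cls_targets) * cls_targets.numel() / n_fg
        box_loss = self.regression(box_preds[fg_mask], box_targets[fg_mask]) / n_fg
        return cls_loss + box_loss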
```
### Class-balanced focal (Cui 2019)
```python
def class_balanced_focal(logits, targets, samples_per_class, beta=0.999, gamma=2.0):
    """Focal loss weighted by the effective number of samples per class."""
    n = torch.as_tensor(samples_per_class, dtype=torch.float, device=logits.device)
    eff_num = 1 - beta ** n
    weights = (1 - beta) / eff_num
    weights = weights / weights.sum() * len(n)  # normalize so the weights sum to C
    log_p = F.log_softmax(logits, -1)
    log_pt = log_p.gather(1, targets.unsqueeze(1)).squeeze(1)
    pt = log_pt.exp()
    return (-weights[targets] * (1 - pt) ** gamma * log_pt).mean()
```
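To see what the effective-number weighting does, the sketch below computes E_n = (1 - β^n) / (1 - β) and the resulting normalized weights in plain Python. The class counts are hypothetical, and the helper names are introduced here for illustration only.

```python
def effective_number(n, beta=0.999):
    """Effective number of samples E_n = (1 - beta^n) / (1 - beta) (Cui et al. 2019)."""
    return (1.0 - beta ** n) / (1.0 - beta)

def class_balanced_weights(samples_per_class, beta=0.999):
    """Weights proportional to 1 / E_n, normalized to sum to the number of classes."""
    raw = [1.0 / effective_number(n, beta) for n in samples_per_class]
    scale = len(raw) / sum(raw)
    return [w * scale for w in raw]

counts = [10_000, 1_000, 10]  # hypothetical long-tailed class counts
for n, w in zip(counts, class_balanced_weights(counts)):
    print(f"n={n:>6}  E_n={effective_number(n):8.2f}  weight={w:.4f}")
```

With β = 0.999, E_n saturates near 1 / (1 - β) = 1000, so the classes with 10,000 and 1,000 samples end up with weights less than 2x apart, whereas raw inverse frequency would separate them by 10x. The rare class still receives by far the largest weight.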
### Asymmetric focal (multi-label)
```python
def asymmetric_focal(logits, targets, gamma_pos=1, gamma_neg=4, clip=0.05):
    """Asymmetric loss (Ben-Baruch 2020): separate γ for positives and negatives."""
    p = torch.sigmoid(logits)
    p_neg = (p - clip).clamp(min=0)  # probability shifting: discard very easy negatives
    pos_loss = targets * (1 - p) ** gamma_pos * torch.log(p.clamp(min=1e-8))
    # Use the shifted probability in both the weight and the log term.
    neg_loss = (1 - targets) * p_neg ** gamma_neg * torch.log((1 - p_neg).clamp(min=1e-8))
    return -(pos_loss + neg_loss).mean()
```
### γ tuning
```python
# Typical γ values (starting points, not fixed rules)
GAMMA_VALUES = {
    'extreme_imbalance': 5.0,
    'object_detection': 2.0,      # RetinaNet
    'medical_rare_disease': 3.0,
    'mild_imbalance': 1.0,
    'no_imbalance': 0.0,          # equivalent to CE
}
```
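How aggressive each γ is can be read off the modulating factor directly. A minimal sketch, with the probe values of p_t chosen arbitrarily for illustration:

```python
def modulating_factor(p_t, gamma):
    """The (1 - p_t)^gamma term that scales cross-entropy in focal loss."""
    return (1.0 - p_t) ** gamma

for gamma in (0.0, 1.0, 2.0, 5.0):
    row = ", ".join(
        f"p_t={p_t}: {modulating_factor(p_t, gamma):.1e}" for p_t in (0.5, 0.9, 0.99)
    )
    print(f"gamma={gamma}: {row}")
```

At γ = 2 an example with p_t = 0.9 is down-weighted by a factor of 100. At γ = 5 even an uncertain example with p_t = 0.5 keeps only about 3% of its CE loss, which is why very high γ can end up ignoring most of the data.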
### Focal vs CE comparison
```python
def compare_losses(p_easy=0.95, p_hard=0.5):
    """Compare how much an easy and a hard example contribute under CE and focal loss."""
    ce_easy = -torch.log(torch.tensor(p_easy))
    ce_hard = -torch.log(torch.tensor(p_hard))
    fl_easy = (1 - p_easy) ** 2 * ce_easy
    fl_hard = (1 - p_hard) ** 2 * ce_hard
    print(f'CE: easy {ce_easy:.3f}, hard {ce_hard:.3f}, ratio {ce_hard/ce_easy:.1f}x')
    print(f'Focal γ=2: easy {fl_easy:.4f}, hard {fl_hard:.3f}, ratio {fl_hard/fl_easy:.0f}x')
    # Under focal loss the hard example dominates far more strongly than under CE.
```
### OHEM alternative
```python
def ohem_loss(logits, targets, ratio=0.25):
    """Keep only the hardest examples in each batch."""
    losses = F.cross_entropy(logits, targets, reduction='none')
    n_keep = max(1, int(len(losses) * ratio))  # keep at least one example to avoid NaN
    topk = losses.topk(n_keep)[0]
    return topk.mean()
```
### Combine: focal + dice (segmentation)
```python
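def focal_dice(logits, targets, smooth=1.0):
    fl = focal_loss_binary(logits, targets)
    p = torch.sigmoid(logits)
    intersection = (p * targets).sum()
    # Smooth both numerator and denominator so an empty, correctly predicted mask gives ~0 loss.
    dice = 1 - (2 * intersection + smooth) / (p.sum() + targets.sum() + smooth)
    return fl + dice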
```
### Eval metric (PR-AUC for imbalanced)
```python
from sklearn.metrics import average_precision_score

def eval_imbalanced(probs, targets):
    return {
        'pr_auc': average_precision_score(targets, probs),
        # Accuracy is deliberately omitted: it is meaningless on imbalanced data.
    }
```
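Why accuracy is left out can be shown with a toy case. The data below is hypothetical: 1% positives and a degenerate model that always predicts the negative class.

```python
def accuracy(preds, targets):
    """Fraction of predictions that match the labels."""
    return sum(p == t for p, t in zip(preds, targets)) / len(targets)

def recall(preds, targets):
    """Fraction of true positives that the model actually finds."""
    positives = sum(targets)
    true_positives = sum(p == 1 and t == 1 for p, t in zip(preds, targets))
    return true_positives / positives if positives else 0.0

# Hypothetical data: 10 positives, 990 negatives, model always says "negative".
targets = [1] * 10 + [0] * 990
preds = [0] * 1000

print(f"accuracy = {accuracy(preds, targets):.2f}, recall = {recall(preds, targets):.2f}")
```

The model scores 0.99 accuracy while finding none of the positives, which is the failure that precision-recall based metrics expose.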
## Decision criteria
| Situation | Loss |
|---|---|
| Dense object detection | Focal (γ=2) |
| Mild imbalance | Class-weighted CE |
| Extreme imbalance | Focal (γ=3-5) |
| Multi-label imbalanced | Asymmetric focal |
| Segmentation | Focal + Dice |
| Easy task | Plain CE |
**Default**: for imbalance above 100:1, use focal loss with γ=2 and α=0.25. For multi-label problems, use asymmetric loss. Evaluate with PR-AUC, not accuracy.
## 🔗 Graph
- Parent: [[Loss-Function]] · [[Object-Detection]]
- Applications: [[RetinaNet]]
- Adjacent: [[Cross-Entropy Loss]]
## 🤖 LLM usage
**When to use**: detection, imbalanced data, medical tasks.
**When not to use**: balanced data, regression.
## ❌ Anti-patterns
- **Focal for balanced**: risks underfitting.
- **γ too high**: most examples are effectively ignored.
- **No α**: class balance is still off.
- **Accuracy metric**: misleading on imbalanced data.
## 🧪 Verification / Duplicates
- Verified (Lin RetinaNet 2017, Cui 2019, Ben-Baruch ASL 2020).
- Trust level A.
## 🕓 Changelog
| Date | Change |
|---|---|
| 2026-04-26 | LOSS-002 auto |
| 2026-05-08 | Phase 1 |
| 2026-05-10 | Manual cleanup: focal + binary / multi / asymmetric / class-balanced code |