---
id: wiki-2026-0508-focal-loss
title: Focal Loss
category: 10_Wiki/Topics
status: verified
canonical_id: self
aliases: [focal loss, class imbalance, RetinaNet, hard example mining, dense detection]
duplicate_of: none
source_trust_level: A
confidence_score: 0.97
verification_status: applied
tags: [deep-learning, loss-function, focal-loss, class-imbalance, object-detection, retinanet]
raw_sources: []
last_reinforced: 2026-05-10
github_commit: pending
tech_stack:
  language: Python
  framework: PyTorch / TensorFlow
---

# Focal Loss

## In one line

> **"Cross-entropy scaled by a (1 - p_t)^γ factor."** Lin et al. 2017 (RetinaNet). It down-weights easy examples and focuses training on hard ones, which enables dense object detection. Used today in detection, imbalanced classification, and instance segmentation.

## Core

### Formula

```
FL(p_t) = -α_t (1 - p_t)^γ log(p_t)
```
- p_t: predicted probability of the true class.
- γ: focusing parameter (typically 2); γ = 0 recovers α-weighted cross-entropy.
- α_t: class-balancing weight (α for the positive class, 1 - α for the negative class).
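
Two special cases make the role of γ concrete. This is a worked reading of the formula above; the probabilities 0.9 and 0.5 are chosen purely for illustration:

$$
\gamma = 0:\quad \mathrm{FL}(p_t) = -\alpha_t \log(p_t)
$$

$$
\gamma = 2:\quad \frac{\mathrm{FL}(p_t)}{-\alpha_t \log(p_t)} = (1 - p_t)^{2} =
\begin{cases}
0.01 & p_t = 0.9 \\
0.25 & p_t = 0.5
\end{cases}
$$

So at γ = 2 a confident prediction (p_t = 0.9) contributes 100× less than it would under α-weighted CE, while an uncertain one (p_t = 0.5) is reduced only 4×.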

### Motivation

- Dense detection: thousands of negatives against tens of positives.
- With plain CE, the many easy negatives dominate the total loss.
- The (1 - p_t)^γ factor shrinks the weight of well-classified examples.
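
The effect can be checked with a few lines of arithmetic. The sketch below is a minimal illustration, not a benchmark: the counts (9000 negatives, 30 positives) and the probabilities (0.95 for easy negatives, 0.3 for hard positives) are assumptions picked to match the "thousands vs tens" regime, and `easy_negative_share` is a hypothetical helper.

```python
import math

def focal_term(p_t, gamma):
    """Per-example focal loss with alpha_t = 1; gamma = 0 gives plain CE."""
    return -((1.0 - p_t) ** gamma) * math.log(p_t)

def easy_negative_share(n_neg, n_pos, p_t_neg, p_t_pos, gamma):
    """Fraction of the summed loss that comes from the easy negatives."""
    neg_total = n_neg * focal_term(p_t_neg, gamma)
    pos_total = n_pos * focal_term(p_t_pos, gamma)
    return neg_total / (neg_total + pos_total)

# Assumed setup: 9000 easy negatives (p_t = 0.95), 30 hard positives (p_t = 0.3)
ce_share = easy_negative_share(9000, 30, 0.95, 0.3, gamma=0.0)
fl_share = easy_negative_share(9000, 30, 0.95, 0.3, gamma=2.0)

print(f"easy-negative share of total loss, CE:          {ce_share:.1%}")
print(f"easy-negative share of total loss, focal (γ=2): {fl_share:.1%}")
```

Under these assumed numbers the easy negatives carry most of the CE total but only a small fraction of the focal total, which is the shift the list above describes.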

### Applications

1. **Object detection** (RetinaNet, FCOS, ATSS).
2. **Instance segmentation**.
3. **Imbalanced classification**.
4. **Medical** (rare disease).
5. **Fraud / anomaly**.

### Alternatives

- **Class weights** (with CE).
- **Online Hard Example Mining (OHEM)**.
- **Balanced sampling**.
- **Asymmetric loss** (multi-label).
- **Class-balanced loss** (Cui 2019).

## 💻 Patterns

### Focal loss (binary)

```python
import torch
import torch.nn.functional as F

def focal_loss_binary(logits, targets, alpha=0.25, gamma=2.0):
    """logits and targets share a shape; targets are float 0/1 labels."""
    p = torch.sigmoid(logits)
    # p_t: probability assigned to the true class of each element
    p_t = p * targets + (1 - p) * (1 - targets)
    # alpha_t: alpha for positives, 1 - alpha for negatives
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    # CE from logits for numerical stability; equals -log(p_t)
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
    focal = alpha_t * (1 - p_t).pow(gamma) * ce
    return focal.mean()
```

### Focal loss (multi-class)

```python
def focal_loss_multi(logits, targets, gamma=2.0, alpha=None):
    """logits: [B, C], targets: [B] (class indices); alpha: optional [C] weight tensor."""
    log_p = F.log_softmax(logits, dim=-1)
    # Log-probability of the true class for each sample
    log_pt = log_p.gather(1, targets.unsqueeze(1)).squeeze(1)
    pt = log_pt.exp()

    focal = -((1 - pt) ** gamma) * log_pt
    if alpha is not None:
        focal = alpha[targets] * focal
    return focal.mean()
```

### As nn.Module

```python
class FocalLoss(torch.nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, logits, targets):
        return focal_loss_binary(logits, targets, self.alpha, self.gamma)
```

### RetinaNet detection

```python
class RetinaNetLoss(torch.nn.Module):
    def __init__(self, n_classes, gamma=2.0, alpha=0.25):
        super().__init__()
        self.n_classes = n_classes
        self.classification = FocalLoss(alpha=alpha, gamma=gamma)
        # Sum here; normalization by the foreground count happens in forward
        self.regression = torch.nn.SmoothL1Loss(reduction='sum')

    def forward(self, cls_logits, box_preds, cls_targets, box_targets, fg_mask):
        n_fg = fg_mask.sum().clamp(min=1)
        # FocalLoss returns a mean over all elements; rescale it to a sum,
        # then normalize by the number of foreground anchors
        cls_loss = self.classification(cls_logits, cls_targets) * cls_logits.numel() / n_fg
        # Box regression is computed on foreground anchors only
        box_loss = self.regression(box_preds[fg_mask], box_targets[fg_mask]) / n_fg
        return cls_loss + box_loss
```

### Class-balanced focal (Cui 2019)

```python
def class_balanced_focal(logits, targets, samples_per_class, beta=0.999, gamma=2.0):
    """Focal loss weighted by the inverse effective number of samples per class."""
    n = torch.as_tensor(samples_per_class, dtype=torch.float32, device=logits.device)
    # Effective number E_n = (1 - beta^n) / (1 - beta); the class weight is 1 / E_n
    weights = (1 - beta) / (1 - beta ** n)
    # Normalize so the weights sum to the number of classes
    weights = weights / weights.sum() * len(n)

    log_p = F.log_softmax(logits, dim=-1)
    # squeeze(1) keeps the batch dimension even when B == 1
    log_pt = log_p.gather(1, targets.unsqueeze(1)).squeeze(1)
    pt = log_pt.exp()
    return (-weights[targets] * (1 - pt) ** gamma * log_pt).mean()
```

### Asymmetric focal (multi-label)

```python
def asymmetric_focal(logits, targets, gamma_pos=1, gamma_neg=4, clip=0.05):
    """Asymmetric loss (Ben-Baruch 2020): separate γ for positives and negatives."""
    p = torch.sigmoid(logits)
    # Probability shifting: negatives with p <= clip contribute zero loss
    p_neg = (p - clip).clamp(min=0)

    pos_loss = targets * (1 - p) ** gamma_pos * torch.log(p.clamp(min=1e-8))
    # The shifted probability is used in both the weight and the log term
    neg_loss = (1 - targets) * p_neg ** gamma_neg * torch.log((1 - p_neg).clamp(min=1e-8))
    return -(pos_loss + neg_loss).mean()
```

### γ tuning

```python
# Rule-of-thumb starting values for γ; tune on a validation set
GAMMA_VALUES = {
    'extreme_imbalance': 5.0,
    'object_detection': 2.0,  # RetinaNet
    'medical_rare_disease': 3.0,
    'mild_imbalance': 1.0,
    'no_imbalance': 0,  # equivalent to CE
}
```
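
To see what each γ does in practice, it helps to print the modulating factor (1 - p_t)^γ at a few confidence levels. This is a minimal sketch; `modulating_factor` is a hypothetical helper and the sampled p_t values are arbitrary.

```python
def modulating_factor(p_t, gamma):
    """(1 - p_t)**gamma: the multiplier focal loss applies to the CE term."""
    return (1.0 - p_t) ** gamma

for gamma in (0.0, 1.0, 2.0, 5.0):
    row = ", ".join(
        f"p_t={p_t}: {modulating_factor(p_t, gamma):.4f}" for p_t in (0.1, 0.5, 0.9)
    )
    print(f"gamma={gamma}: {row}")
```

At γ = 5 even an uncertain example with p_t = 0.5 keeps only 0.5^5 ≈ 3% of its CE loss, which is why large γ values are worth validating rather than assuming.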

### Focal vs CE comparison

```python
def compare_losses(p_easy=0.95, p_hard=0.5):
    """Compare the loss contribution of an easy and a hard example."""
    ce_easy = -torch.log(torch.tensor(p_easy))
    ce_hard = -torch.log(torch.tensor(p_hard))

    # Focal loss with γ=2 and no α weighting
    fl_easy = (1 - p_easy) ** 2 * ce_easy
    fl_hard = (1 - p_hard) ** 2 * ce_hard

    print(f'CE: easy {ce_easy:.3f}, hard {ce_hard:.3f}, ratio {ce_hard/ce_easy:.1f}x')
    print(f'Focal γ=2: easy {fl_easy:.4f}, hard {fl_hard:.3f}, ratio {fl_hard/fl_easy:.0f}x')
    # Under focal loss the hard example outweighs the easy one by a much larger ratio
```

### OHEM alternative

```python
def ohem_loss(logits, targets, ratio=0.25):
    """Keep only the hardest examples in each batch."""
    losses = F.cross_entropy(logits, targets, reduction='none')
    # Keep at least one example so the mean is never taken over an empty tensor
    n_keep = max(1, int(len(losses) * ratio))
    topk = losses.topk(n_keep)[0]
    return topk.mean()
```

### Combine: focal + dice (segmentation)

```python
def focal_dice(logits, targets, smooth=1.0):
    fl = focal_loss_binary(logits, targets)
    p = torch.sigmoid(logits)
    intersection = (p * targets).sum()
    # Smooth both numerator and denominator so a perfect match gives dice = 0
    dice = 1 - (2 * intersection + smooth) / (p.sum() + targets.sum() + smooth)
    return fl + dice
```

### Eval metric (PR-AUC for imbalanced)

```python
from sklearn.metrics import average_precision_score

def eval_imbalanced(probs, targets):
    return {
        'pr_auc': average_precision_score(targets, probs),
        # Not accuracy: it is meaningless for imbalanced data
    }
```

## Decision criteria

| Situation | Loss |
|---|---|
| Dense object detection | Focal (γ=2) |
| Mild imbalance | Class-weighted CE |
| Extreme imbalance | Focal (γ=3-5) |
| Multi-label imbalanced | Asymmetric focal |
| Segmentation | Focal + Dice |
| Easy task | Plain CE |

**Default**: for imbalance above 100:1, use focal loss with γ=2 and α=0.25. For multi-label problems, use asymmetric focal. Evaluate with PR-AUC, not accuracy.

## 🔗 Graph

- Parent: [[Loss-Function]] · [[Object-Detection]]
- Applications: [[RetinaNet]]
- Adjacent: [[Cross-Entropy-Loss]]

## 🤖 LLM usage

**When**: detection, imbalanced data, medical tasks.
**When not**: balanced data, regression.

## ❌ Anti-patterns

- **Focal loss on balanced data**: can underfit.
- **γ too high**: most examples are effectively ignored.
- **No α**: class balance is still off.
- **Accuracy as the metric**: misleading on imbalanced data.
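
The last point is easy to demonstrate. In the sketch below, the 10-in-1000 positive rate and the always-negative "model" are illustrative assumptions, and `accuracy_and_recall` is a hypothetical helper:

```python
def accuracy_and_recall(preds, labels):
    """Accuracy and positive-class recall for binary 0/1 predictions."""
    correct = sum(p == y for p, y in zip(preds, labels))
    true_pos = sum(p == 1 and y == 1 for p, y in zip(preds, labels))
    n_pos = sum(labels)
    return correct / len(labels), true_pos / n_pos

# Assumed dataset: 10 positives among 1000 samples
labels = [1] * 10 + [0] * 990
# A degenerate classifier that never predicts the positive class
always_negative = [0] * len(labels)

acc, rec = accuracy_and_recall(always_negative, labels)
print(f"accuracy={acc:.3f}, recall={rec:.3f}")
```

The classifier scores 99% accuracy while finding none of the positives, so a metric that tracks the positive class (such as PR-AUC) is the safer choice here.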

## 🧪 Verification / duplicates

- Verified (Lin RetinaNet 2017, Cui 2019, Ben-Baruch ASL 2020).
- Trust level A.

## 🕓 Changelog

| Date | Change |
|---|---|
| 2026-04-26 | LOSS-002 auto |
| 2026-05-08 | Phase 1 |
| 2026-05-10 | Manual cleanup: focal + binary / multi / asymmetric / class-balanced code |