---
id: wiki-2026-0508-focal-loss
title: Focal Loss
category: 10_Wiki/Topics
status: verified
canonical_id: self
aliases: [focal loss, class imbalance, RetinaNet, hard example mining, dense detection]
duplicate_of: none
source_trust_level: A
confidence_score: 0.97
verification_status: applied
tags: [deep-learning, loss-function, focal-loss, class-imbalance, object-detection, retinanet]
raw_sources: []
last_reinforced: 2026-05-10
github_commit: pending
tech_stack:
  language: Python
  framework: PyTorch / TensorFlow
---
# Focal Loss
## In one line
> **"Cross-entropy multiplied by a (1 − p_t)^γ modulating factor."** Introduced by Lin et al. 2017 (RetinaNet). It down-weights easy examples so that training focuses on hard ones, which is what made dense (one-stage) object detection train effectively. Common uses today: object detection, imbalanced classification, instance segmentation.
## Core ideas
### Formula
```
FL(p_t) = -α_t (1 - p_t)^γ log(p_t)
```
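- p_t: the model's predicted probability for the true class (p_t = p if y = 1, else 1 − p).
- γ: focusing parameter (typically 2). With γ = 0 the loss reduces to (α-weighted) cross-entropy.
- α_t: class-balancing weight (α for the positive class, 1 − α for the negative class). RetinaNet uses α = 0.25, so positives get the *smaller* weight, because the (1 − p_t)^γ term already suppresses the easy negatives.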
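A minimal scalar sketch of the formula, in plain Python so each term is visible. `focal_loss_scalar` is a hypothetical helper, not a library function; with γ = 0 and α_t = 1 it reduces to ordinary cross-entropy, and the ratio FL/CE is exactly the modulating factor (1 − p_t)^γ.

```python
import math

def focal_loss_scalar(p_t: float, gamma: float = 2.0, alpha_t: float = 1.0) -> float:
    """FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t) for a single example."""
    if not 0.0 < p_t <= 1.0:
        raise ValueError("p_t must be in (0, 1]")
    modulating = (1.0 - p_t) ** gamma  # shrinks toward 0 as p_t -> 1
    return -alpha_t * modulating * math.log(p_t)

for p_t in (0.1, 0.5, 0.9, 0.99):
    ce = focal_loss_scalar(p_t, gamma=0.0)   # gamma = 0, alpha_t = 1: plain CE
    fl = focal_loss_scalar(p_t, gamma=2.0)
    print(f"p_t={p_t:.2f}  CE={ce:.4f}  FL(γ=2)={fl:.6f}  FL/CE={fl / ce:.4f}")
```

Confident examples (p_t near 1) keep only a tiny fraction of their CE, while hard examples (p_t near 0) keep most of it.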
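### Motivation
- In dense detection, an image yields thousands of negative (background) candidates but only tens of positives.
- Under plain CE each easy negative contributes a small loss, yet their sum dominates the total loss and gradient.
- The (1 − p_t)^γ factor shrinks the weight of well-classified examples (p_t → 1), so the hard examples dominate instead.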
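To make the "sum dominates" point concrete, here is a toy calculation. The counts and probabilities are made up for illustration, not taken from the paper: each easy negative contributes little, but ten thousand of them outweigh a few hard positives under CE, and the modulating factor reverses that.

```python
import math

def total_losses(n: int, p_t: float, gamma: float) -> tuple[float, float]:
    """Summed CE and focal loss (alpha omitted) for n examples sharing the same p_t."""
    ce = -math.log(p_t)
    return n * ce, n * (1 - p_t) ** gamma * ce

# Hypothetical batch: 10,000 easy negatives (p_t = 0.99) and 10 hard positives (p_t = 0.3)
easy_ce, easy_fl = total_losses(10_000, 0.99, gamma=2.0)
hard_ce, hard_fl = total_losses(10, 0.3, gamma=2.0)

# Share of the total loss that comes from the easy negatives
# (about 89.3% under CE, about 0.2% under focal loss)
print(f"CE    easy share: {easy_ce / (easy_ce + hard_ce):.1%}")
print(f"Focal easy share: {easy_fl / (easy_fl + hard_fl):.1%}")
```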
### Applications
1. **Object detection** (RetinaNet, FCOS, ATSS).
2. **Instance segmentation**.
3. **Imbalanced classification**.
4. **Medical** (rare-disease detection).
5. **Fraud / anomaly detection**.
### Alternatives
- **Class-weighted CE**.
- **Online Hard Example Mining (OHEM)**.
- **Balanced sampling**.
- **Asymmetric loss** (multi-label).
- **Class-balanced loss** (Cui 2019).
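Two of these alternatives map directly onto standard PyTorch APIs: the `weight` argument of `torch.nn.functional.cross_entropy` for class weights, and `torch.utils.data.WeightedRandomSampler` for balanced sampling. A minimal sketch, assuming inverse-frequency weighting (one common choice among several); the helper names are hypothetical.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import WeightedRandomSampler

def inverse_frequency_weights(labels: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Per-class weights proportional to 1 / count, rescaled to average 1."""
    counts = torch.bincount(labels, minlength=num_classes).float()
    weights = 1.0 / counts.clamp(min=1)  # avoid division by zero for absent classes
    return weights * num_classes / weights.sum()

def weighted_ce(logits: torch.Tensor, targets: torch.Tensor, class_weights: torch.Tensor) -> torch.Tensor:
    # Class-weighted CE: rare-class errors cost more
    return F.cross_entropy(logits, targets, weight=class_weights)

def balanced_sampler(labels: torch.Tensor, num_classes: int) -> WeightedRandomSampler:
    # Each sample is drawn with probability proportional to 1 / (size of its class),
    # so every class appears about equally often per epoch
    counts = torch.bincount(labels, minlength=num_classes).float()
    per_sample = (1.0 / counts.clamp(min=1))[labels]
    return WeightedRandomSampler(per_sample.double(), num_samples=len(labels), replacement=True)
```

The sampler is passed as `DataLoader(dataset, sampler=..., batch_size=...)`; class weights and resampling both change the effective class prior, so they are usually not combined at full strength.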
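## 💻 Patterns
### Focal loss (binary)
```python
import torch
import torch.nn.functional as F

def focal_loss_binary(logits, targets, alpha=0.25, gamma=2.0, reduction='mean'):
    """Sigmoid focal loss. `targets` has the same shape as `logits`, with 0/1 values."""
    targets = targets.float()
    p = torch.sigmoid(logits)
    # p_t: predicted probability of the true label
    p_t = p * targets + (1 - p) * (1 - targets)
    # alpha_t: alpha for positives, 1 - alpha for negatives
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    # -log(p_t), computed stably from the logits
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
    focal = alpha_t * (1 - p_t).pow(gamma) * ce
    if reduction == 'none':
        return focal
    return focal.sum() if reduction == 'sum' else focal.mean()
```
### Focal loss (multi-class)
```python
def focal_loss_multi(logits, targets, gamma=2.0, alpha=None):
    """Softmax focal loss. logits: [B, C] float; targets: [B] long; alpha: optional [C] tensor."""
    log_p = F.log_softmax(logits, dim=-1)
    # Log-probability of the true class for each sample
    log_pt = log_p.gather(1, targets.unsqueeze(1)).squeeze(1)
    pt = log_pt.exp()
    focal = -((1 - pt) ** gamma) * log_pt
    if alpha is not None:
        # Per-class weight looked up by each sample's label
        focal = alpha[targets] * focal
    return focal.mean()
```
### As nn.Module
```python
class FocalLoss(torch.nn.Module):
    """nn.Module wrapper around focal_loss_binary."""
    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, targets):
        return focal_loss_binary(logits, targets, self.alpha, self.gamma, self.reduction)
```
### RetinaNet detection
```python
class RetinaNetLoss(torch.nn.Module):
    def __init__(self, gamma=2.0, alpha=0.25):
        super().__init__()
        # Sum over anchors here; normalization by the foreground count happens in forward()
        self.classification = FocalLoss(alpha=alpha, gamma=gamma, reduction='sum')
        self.regression = torch.nn.SmoothL1Loss(reduction='sum')

    def forward(self, cls_logits, box_preds, cls_targets, box_targets, fg_mask):
        """cls_*: [A, C] with one-hot targets; box_*: [A, 4]; fg_mask: [A] bool."""
        n_fg = fg_mask.sum().clamp(min=1)
        cls_loss = self.classification(cls_logits, cls_targets)
        # Box regression only on foreground anchors
        box_loss = self.regression(box_preds[fg_mask], box_targets[fg_mask])
        # RetinaNet normalizes the summed loss by the number of foreground anchors
        return (cls_loss + box_loss) / n_fg
```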
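The RetinaNet paper pairs focal loss with a "prior" initialization of the classification-head bias, b = −log((1 − π)/π) with π = 0.01, so that at the start of training every anchor predicts foreground with probability about 0.01 and the flood of background anchors does not produce a large, destabilizing loss. A minimal sketch; `init_focal_prior` is a hypothetical helper and the head shape (9 anchors × 80 classes, 256 input channels) is an illustrative assumption.

```python
import math
import torch

def init_focal_prior(cls_conv: torch.nn.Conv2d, prior: float = 0.01) -> None:
    """Set the classification-head bias so that sigmoid(bias) == prior."""
    # b = -log((1 - pi) / pi)  <=>  sigmoid(b) = pi
    torch.nn.init.constant_(cls_conv.bias, -math.log((1 - prior) / prior))

# Hypothetical final conv of a classification head: one output per (anchor, class)
num_anchors, num_classes = 9, 80
cls_conv = torch.nn.Conv2d(256, num_anchors * num_classes, kernel_size=3, padding=1)
init_focal_prior(cls_conv)
```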
### Class-balanced focal (Cui 2019)
```python
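import torch
import torch.nn.functional as F

def class_balanced_focal(logits, targets, samples_per_class, beta=0.999, gamma=2.0):
    """Softmax focal loss reweighted by the inverse effective number of samples per class."""
    n = torch.as_tensor(samples_per_class, dtype=torch.float, device=logits.device)
    # Effective number E_n = (1 - beta^n) / (1 - beta); the class weight is 1 / E_n
    eff_num = 1 - beta ** n
    weights = (1 - beta) / eff_num
    # Normalize so the weights sum to the number of classes
    weights = weights / weights.sum() * len(n)
    log_p = F.log_softmax(logits, dim=-1)
    # squeeze(1), not squeeze(): keeps the batch dimension when B == 1
    log_pt = log_p.gather(1, targets.unsqueeze(1)).squeeze(1)
    pt = log_pt.exp()
    return (-weights[targets] * (1 - pt) ** gamma * log_pt).mean()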
```
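The β hyperparameter interpolates between no reweighting and inverse-frequency weighting: with β = 0 every class gets weight 1, and as β → 1 the effective number approaches the raw count n. A small plain-Python sketch (hypothetical helper, same normalization as the block above) for inspecting the weights before training; the class counts are made up.

```python
def effective_number_weights(samples_per_class, beta):
    """Weights proportional to 1 / E_n with E_n = (1 - beta**n) / (1 - beta),
    normalized to sum to the number of classes. Requires 0 <= beta < 1 and n >= 1."""
    raw = [(1 - beta) / (1 - beta ** n) for n in samples_per_class]
    total = sum(raw)
    return [w * len(raw) / total for w in raw]

counts = [5000, 500, 50]  # hypothetical long-tailed class counts
for beta in (0.0, 0.9, 0.99, 0.999, 0.9999):
    print(beta, [round(w, 3) for w in effective_number_weights(counts, beta)])
```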
### Asymmetric focal (multi-label)
```python
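import torch

def asymmetric_focal(logits, targets, gamma_pos=1, gamma_neg=4, clip=0.05, eps=1e-8):
    """Asymmetric Loss (Ben-Baruch 2020): separate γ for positives and negatives,
    plus probability shifting (clip) on negatives. targets: 0/1 floats, same shape as logits."""
    p = torch.sigmoid(logits)
    # Shifted probability for negatives: p_m = max(p - clip, 0)
    p_m = (p - clip).clamp(min=0)
    pos_loss = targets * (1 - p) ** gamma_pos * torch.log(p.clamp(min=eps))
    # ASL uses p_m both in the focusing term and inside the log
    neg_loss = (1 - targets) * p_m ** gamma_neg * torch.log((1 - p_m).clamp(min=eps))
    return -(pos_loss + neg_loss).mean()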
```
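The `clip` margin is what makes ASL more than "focal loss with two γ values": any negative whose probability is already at or below `clip` contributes exactly zero loss. A plain-Python sketch for a single negative label (hypothetical helper, same defaults as the block above):

```python
import math

def asl_negative_loss(p: float, gamma_neg: float = 4.0, clip: float = 0.05) -> float:
    """ASL loss for one negative label with predicted probability p in [0, 1]."""
    p_m = max(p - clip, 0.0)  # probability shifting
    if p_m == 0.0:
        return 0.0            # easy negative: discarded entirely
    # With clip > 0, p_m <= 1 - clip, so log(1 - p_m) is always finite
    return -(p_m ** gamma_neg) * math.log(1.0 - p_m)

for p in (0.03, 0.2, 0.6, 0.95):
    print(f"p={p:.2f}  loss={asl_negative_loss(p):.5f}")
```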
### γ tuning
```python
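# Rough starting points for γ. Only the detection value (γ=2) comes from the
# RetinaNet paper; the others are heuristics to tune on validation data.
GAMMA_VALUES = {
    'extreme_imbalance': 5.0,
    'object_detection': 2.0,  # RetinaNet
    'medical_rare_disease': 3.0,
    'mild_imbalance': 1.0,
    'no_imbalance': 0.0,  # γ = 0 reduces focal loss to (α-weighted) CE
}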
```
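To see why a large γ risks ignoring most examples, tabulate the multiplier (1 − p_t)^γ for a few confidence levels. A plain-Python sketch; `modulating_factor` is a hypothetical helper:

```python
def modulating_factor(p_t: float, gamma: float) -> float:
    """(1 - p_t)^gamma: the multiplier focal loss applies to an example's CE."""
    return (1.0 - p_t) ** gamma

gammas = (0.0, 1.0, 2.0, 3.0, 5.0)
print("p_t   " + "  ".join(f"γ={g:g}".rjust(9) for g in gammas))
for p_t in (0.3, 0.6, 0.9, 0.99):
    print(f"{p_t:<5} " + "  ".join(f"{modulating_factor(p_t, g):9.2e}" for g in gammas))
```

At γ = 5 even a moderately confident example (p_t = 0.6) keeps only about 1% of its CE, so the gradient comes from a shrinking set of hard (and possibly mislabeled) examples.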
### Focal vs CE comparison
```python
import math

def compare_losses(p_easy=0.95, p_hard=0.5, gamma=2.0):
    """Contribution of one easy and one hard example under CE vs focal loss (α omitted)."""
    ce_easy = -math.log(p_easy)
    ce_hard = -math.log(p_hard)
    fl_easy = (1 - p_easy) ** gamma * ce_easy
    fl_hard = (1 - p_hard) ** gamma * ce_hard
    print(f'CE: easy {ce_easy:.3f}, hard {ce_hard:.3f}, ratio {ce_hard / ce_easy:.1f}x')
    print(f'Focal γ={gamma:g}: easy {fl_easy:.4f}, hard {fl_hard:.3f}, ratio {fl_hard / fl_easy:.0f}x')
    # With the defaults the hard/easy ratio grows from ~13.5x (CE) to ~1351x (focal)
```
### OHEM alternative
```python
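import torch.nn.functional as F

def ohem_loss(logits, targets, ratio=0.25):
    """Keep only the hardest `ratio` fraction of examples in the batch (highest CE)."""
    losses = F.cross_entropy(logits, targets, reduction='none')
    # Keep at least one example; topk(0).mean() would return NaN
    n_keep = max(1, int(losses.numel() * ratio))
    return losses.topk(n_keep).values.mean()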
```
### Combine: focal + dice (segmentation)
```python
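import torch

def focal_dice(logits, targets, smooth=1.0):
    """Binary segmentation loss: focal_loss_binary (defined above) + soft Dice."""
    targets = targets.float()
    fl = focal_loss_binary(logits, targets)
    p = torch.sigmoid(logits)
    intersection = (p * targets).sum()
    # Smoothing in both numerator and denominator keeps an empty mask with an
    # empty prediction at Dice loss ≈ 0 instead of 1
    dice = 1 - (2 * intersection + smooth) / (p.sum() + targets.sum() + smooth)
    return fl + dice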
```
### Eval metric (PR-AUC for imbalanced)
```python
from sklearn.metrics import average_precision_score

def eval_imbalanced(probs, targets):
    return {
        # Average precision: step-wise summary of the precision-recall curve
        'pr_auc': average_precision_score(targets, probs),
        # Not accuracy: it is misleading under heavy class imbalance
    }
```
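Why accuracy is excluded: on a 99:1 split, a model that never predicts the rare class still scores 99% accuracy while finding none of the positives. A dependency-free sketch with made-up data; `accuracy_and_recall` is a hypothetical helper:

```python
def accuracy_and_recall(y_true, y_pred):
    """Accuracy and positive-class recall for binary 0/1 labels."""
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    tp = sum(t == 1 and p == 1 for t, p in zip(y_true, y_pred))
    positives = sum(y_true)
    return correct / len(y_true), (tp / positives if positives else 0.0)

# Hypothetical 99:1 dataset and a degenerate model that always predicts "negative"
y_true = [1] * 10 + [0] * 990
y_pred = [0] * len(y_true)
acc, recall = accuracy_and_recall(y_true, y_pred)
print(f"accuracy={acc:.3f}  recall on positives={recall:.3f}")
```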
## Decision guide
| Situation | Loss |
|---|---|
| Dense object detection | Focal (γ=2) |
| Mild imbalance | Class-weighted CE |
| Extreme imbalance | Focal (γ=3-5) |
| Multi-label imbalanced | Asymmetric focal |
| Segmentation | Focal + Dice |
| Easy task | Plain CE |

**Default**: if the imbalance exceeds roughly 100:1, start with focal loss (γ=2, α=0.25). For multi-label problems, use asymmetric loss. Evaluate with PR-AUC, not accuracy.
## 🔗 Graph
- Parent: [[Loss-Function]] · [[Object-Detection]]
- Applications: [[RetinaNet]]
- Adjacent: [[Cross-Entropy-Loss]]
## 🤖 LLM usage
**When**: object detection; imbalanced classification; medical (rare-disease) tasks.
**When not**: balanced datasets; regression tasks.
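## ❌ Anti-patterns
- **Focal loss on balanced data**: down-weighting examples that are not actually redundant can lead to underfitting.
- **γ too high**: most examples get a near-zero weight and are effectively ignored.
- **No α**: the class balance can still be off.
- **Accuracy as the metric**: misleading on imbalanced data.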
## 🧪 Verification / duplicates
- Verified (Lin RetinaNet 2017, Cui 2019, Ben-Baruch ASL 2020).
- Confidence: A.
## 🕓 Changelog
| Date | Change |
|---|---|
| 2026-04-26 | LOSS-002 auto |
| 2026-05-08 | Phase 1 |
| 2026-05-10 | Manual cleanup: focal loss with binary / multi-class / asymmetric / class-balanced code |