| field | value |
|---|---|
| id | wiki-2026-0508-focal-loss |
| title | Focal Loss |
| category | 10_Wiki/Topics |
| status | verified |
| canonical_id | self |
| aliases | |
| duplicate_of | none |
| source_trust_level | A |
| confidence_score | 0.97 |
| verification_status | applied |
| tags | |
| raw_sources | |
| last_reinforced | 2026-05-10 |
| github_commit | pending |
| tech_stack | |
Focal Loss
In one line
Cross-entropy with an extra (1 - p_t)^γ modulating factor (Lin 2017, RetinaNet). It down-weights easy examples and focuses training on hard ones, which made dense one-stage object detection work. Modern uses: object detection, imbalanced classification, instance segmentation.
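Key points
Formula
FL(p_t) = -α_t (1 - p_t)^γ log(p_t)
- p_t: predicted probability of the true class (p_t = p if y = 1, else 1 - p).
- γ: focusing parameter (γ = 2 in RetinaNet; γ = 0 recovers α-weighted CE).
- α_t: class-balancing weight (α for positives, 1 - α for negatives; α = 0.25 in RetinaNet).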
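As a quick check of the formula, here is a plain-Python sketch for a single binary prediction. The helper name `focal_term` is hypothetical, not a library function. It shows that γ = 0 reduces FL to α-weighted cross-entropy, and how strongly a confident correct prediction is down-weighted at γ = 2.

```python
import math

def focal_term(p, y, alpha=0.25, gamma=2.0):
    """Focal loss for one binary prediction (hypothetical helper).

    p: predicted probability of the positive class; y: true label (0 or 1).
    """
    p_t = p if y == 1 else 1 - p              # probability assigned to the true class
    alpha_t = alpha if y == 1 else 1 - alpha  # class-balancing weight
    return -alpha_t * (1 - p_t) ** gamma * math.log(p_t)

# gamma = 0 with alpha = 0.5 is just half the ordinary cross-entropy.
print(focal_term(0.9, 1, alpha=0.5, gamma=0.0), 0.5 * -math.log(0.9))

# A confident correct prediction (p_t = 0.9) keeps (1 - 0.9)^2, about 1%, of its CE at gamma = 2.
print(focal_term(0.9, 1, alpha=0.5, gamma=2.0) / focal_term(0.9, 1, alpha=0.5, gamma=0.0))
```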
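Motivation
- Dense detection: thousands of negative (background) anchors per image vs. tens of positives.
- With plain CE, each easy negative contributes a small loss, but together they dominate the total loss and the gradient.
- The (1 - p_t)^γ factor shrinks the loss of well-classified examples (p_t → 1), so hard examples dominate instead.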
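To make the imbalance argument concrete, the sketch below computes what share of the total loss comes from easy examples under CE and under focal loss. The counts and probabilities (10,000 easy negatives at p_t = 0.99, 10 hard positives at p_t = 0.3) are illustrative assumptions, and α is left out to isolate the effect of γ.

```python
import math

def easy_loss_share(n_easy, pt_easy, n_hard, pt_hard, gamma):
    """Fraction of the total loss contributed by the easy examples (alpha omitted)."""
    def per_example(pt):
        return (1 - pt) ** gamma * -math.log(pt)  # gamma = 0 gives plain CE
    easy = n_easy * per_example(pt_easy)
    hard = n_hard * per_example(pt_hard)
    return easy / (easy + hard)

ce_share = easy_loss_share(10_000, 0.99, 10, 0.3, gamma=0.0)
fl_share = easy_loss_share(10_000, 0.99, 10, 0.3, gamma=2.0)
print(f'easy-example share of loss: CE {ce_share:.1%}, focal (gamma=2) {fl_share:.2%}')
```

Under these assumptions the easy negatives account for roughly 89% of the CE loss but well under 1% of the focal loss.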
Applications
- Object detection (RetinaNet, FCOS, ATSS).
- Instance segmentation.
- Imbalanced classification.
- Medical (rare disease).
- Fraud / anomaly.
Alternatives
- Class-weighted CE.
- Online Hard Example Mining (OHEM).
- Balanced sampling.
- Asymmetric loss (multi-label).
- Class-balanced loss (Cui 2019).
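The first and third alternatives need no custom loss in PyTorch: `F.cross_entropy` accepts a per-class `weight` tensor, and `torch.utils.data.WeightedRandomSampler` draws samples with given weights. A minimal sketch, assuming integer class labels and inverse-frequency weights (one common choice among several); the helper names are hypothetical.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import WeightedRandomSampler

def inverse_frequency_weights(targets, n_classes):
    """Per-class weights proportional to 1 / count, rescaled to average 1 (hypothetical helper)."""
    counts = torch.bincount(targets, minlength=n_classes).float().clamp(min=1)
    weights = 1.0 / counts
    return weights * n_classes / weights.sum()

def class_weighted_ce(logits, targets, class_weights):
    # Class-weighted CE: cross_entropy scales each example's loss by weight[target].
    return F.cross_entropy(logits, targets, weight=class_weights)

def balanced_sampler(targets, n_classes):
    # Balanced sampling: each sample is drawn with probability proportional to its class weight,
    # so every class appears about equally often in expectation.
    sample_weights = inverse_frequency_weights(targets, n_classes)[targets]
    return WeightedRandomSampler(sample_weights, num_samples=len(targets), replacement=True)

# Usage on a toy 95:5 label set.
targets = torch.tensor([0] * 95 + [1] * 5)
logits = torch.randn(100, 2)
w = inverse_frequency_weights(targets, 2)
loss = class_weighted_ce(logits, targets, w)
sampler = balanced_sampler(targets, 2)
indices = list(sampler)  # 100 dataset indices, about half from the minority class
```

Reweighting changes the loss while keeping the data stream unchanged; the sampler instead changes which examples the model sees, repeating minority examples many times per epoch.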
💻 Patterns
Focal loss (binary)
```python
import torch
import torch.nn.functional as F

def focal_loss_binary(logits, targets, alpha=0.25, gamma=2.0):
    """Sigmoid focal loss. logits, targets: same shape; targets are floats in {0, 1}."""
    p = torch.sigmoid(logits)
    p_t = p * targets + (1 - p) * (1 - targets)              # probability of the true class
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)  # class-balancing weight
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')  # = -log(p_t)
    focal = alpha_t * (1 - p_t).pow(gamma) * ce
    return focal.mean()
```
Focal loss (multi-class)
```python
def focal_loss_multi(logits, targets, gamma=2.0, alpha=None):
    """logits: [B, C], targets: [B] class indices, alpha: optional [C] tensor of class weights."""
    log_p = F.log_softmax(logits, dim=-1)
    log_pt = log_p.gather(1, targets.unsqueeze(1)).squeeze(1)  # log-probability of the true class
    pt = log_pt.exp()
    focal = -((1 - pt) ** gamma) * log_pt
    if alpha is not None:
        focal = alpha[targets] * focal  # per-class weight
    return focal.mean()
```
As nn.Module
```python
class FocalLoss(torch.nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, logits, targets):
        return focal_loss_binary(logits, targets, self.alpha, self.gamma)
```
RetinaNet detection
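```python
class RetinaNetLoss(torch.nn.Module):
    def __init__(self, n_classes, gamma=2.0, alpha=0.25):
        super().__init__()
        self.classification = FocalLoss(alpha=alpha, gamma=gamma)
        self.regression = torch.nn.SmoothL1Loss(reduction='sum')

    def forward(self, cls_logits, box_preds, cls_targets, box_targets, fg_mask):
        # cls_logits, cls_targets: [N, C] (float one-hot); box_preds, box_targets: [N, 4]; fg_mask: bool [N]
        # FocalLoss returns a mean; multiply back to the sum over all anchors x classes.
        cls_loss = self.classification(cls_logits, cls_targets) * cls_logits.numel()
        box_loss = self.regression(box_preds[fg_mask], box_targets[fg_mask])
        # Normalize the summed losses by the number of foreground anchors (RetinaNet does this for the focal term).
        n_fg = fg_mask.sum().clamp(min=1)
        return (cls_loss + box_loss) / n_fg
```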
Class-balanced focal (Cui 2019)
```python
def class_balanced_focal(logits, targets, samples_per_class, beta=0.999, gamma=2.0):
    """Weight each class by the inverse effective number of samples: (1 - beta) / (1 - beta^n_c)."""
    n = torch.as_tensor(samples_per_class, dtype=logits.dtype, device=logits.device)
    eff_num = 1 - beta ** n
    weights = (1 - beta) / eff_num
    weights = weights / weights.sum() * len(n)  # normalize so the weights sum to C
    log_p = F.log_softmax(logits, -1)
    log_pt = log_p.gather(1, targets.unsqueeze(1)).squeeze(1)
    pt = log_pt.exp()
    return (-weights[targets] * (1 - pt) ** gamma * log_pt).mean()
```
Asymmetric focal (multi-label)
```python
def asymmetric_focal(logits, targets, gamma_pos=1, gamma_neg=4, clip=0.05):
    """Ben-Baruch 2020 (ASL): separate gamma for positives and negatives, plus probability shifting."""
    p = torch.sigmoid(logits)
    p_neg = (p - clip).clamp(min=0)  # shifted probability: very easy negatives (p < clip) get zero loss
    pos_loss = targets * (1 - p) ** gamma_pos * torch.log(p.clamp(min=1e-8))
    neg_loss = (1 - targets) * p_neg ** gamma_neg * torch.log((1 - p_neg).clamp(min=1e-8))
    return -(pos_loss + neg_loss).mean()
```
γ tuning
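```python
# Rough starting values for gamma (heuristics; tune on a validation set)
GAMMA_VALUES = {
    'extreme_imbalance': 5.0,
    'object_detection': 2.0,  # RetinaNet
    'medical_rare_disease': 3.0,
    'mild_imbalance': 1.0,
    'no_imbalance': 0.0,      # gamma = 0 reduces to (alpha-weighted) CE
}
```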
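To see why a very large γ can effectively ignore most examples, the sketch below prints the modulating factor (1 - p_t)^γ for the γ values above and a few p_t values. The p_t grid is an arbitrary choice for illustration.

```python
def modulating_factor(p_t, gamma):
    """Fraction of an example's CE loss that survives under focal loss: (1 - p_t)^gamma."""
    return (1 - p_t) ** gamma

for gamma in (0.0, 1.0, 2.0, 3.0, 5.0):
    row = '  '.join(f'p_t={p:.1f}: {modulating_factor(p, gamma):.4f}' for p in (0.3, 0.6, 0.9))
    print(f'gamma={gamma}: {row}')
```

At γ = 5 even a moderately classified example (p_t = 0.6) keeps only about 1% of its CE loss (0.4^5 ≈ 0.010), so very high γ can leave little training signal.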
Focal vs CE comparison
```python
def compare_losses(p_easy=0.95, p_hard=0.5, gamma=2.0):
    """Compare how much an easy and a hard example contribute under CE vs focal loss."""
    ce_easy = -torch.log(torch.tensor(p_easy))
    ce_hard = -torch.log(torch.tensor(p_hard))
    fl_easy = (1 - p_easy) ** gamma * ce_easy
    fl_hard = (1 - p_hard) ** gamma * ce_hard
    print(f'CE: easy {ce_easy:.3f}, hard {ce_hard:.3f}, ratio {ce_hard/ce_easy:.1f}x')
    print(f'Focal gamma={gamma}: easy {fl_easy:.4f}, hard {fl_hard:.3f}, ratio {fl_hard/fl_easy:.0f}x')
    # With the defaults, the hard/easy ratio grows from about 14x (CE) to about 1350x (focal, gamma=2).
```
OHEM alternative
```python
def ohem_loss(logits, targets, ratio=0.25):
    """Keep only the hardest examples (highest CE) in each batch."""
    losses = F.cross_entropy(logits, targets, reduction='none')
    n_keep = max(1, int(len(losses) * ratio))  # keep at least one example
    topk = losses.topk(n_keep).values
    return topk.mean()
```
Combine: focal + dice (segmentation)
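```python
def focal_dice(logits, targets, smooth=1.0):
    fl = focal_loss_binary(logits, targets)
    p = torch.sigmoid(logits)
    intersection = (p * targets).sum()
    # Soft Dice loss; `smooth` in both numerator and denominator gives a correctly
    # predicted empty mask a loss near 0 instead of 1.
    dice = 1 - (2 * intersection + smooth) / (p.sum() + targets.sum() + smooth)
    return fl + dice
```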
Eval metric (PR-AUC for imbalanced)
```python
from sklearn.metrics import average_precision_score

def eval_imbalanced(probs, targets):
    return {
        'pr_auc': average_precision_score(targets, probs),
        # not accuracy: it is misleading under heavy class imbalance
    }
```
Decision guide
| Situation | Loss |
|---|---|
| Dense object detection | Focal (γ=2) |
| Mild imbalance | Class-weighted CE |
| Extreme imbalance | Focal (γ=3-5) |
| Multi-label imbalanced | Asymmetric focal |
| Segmentation | Focal + Dice |
| Easy task | Plain CE |
Defaults: for imbalance above roughly 100:1, start with focal loss (γ = 2, α = 0.25). For multi-label problems, use asymmetric loss. Evaluate with PR-AUC, not accuracy.
🔗 Graph
- Parent: Loss-Function · Object-Detection
- Application: RetinaNet
- Adjacent: Cross-Entropy-Loss
🤖 LLM usage
When to use: detection, imbalanced data, medical tasks. When not to: balanced data, regression.
❌ Anti-patterns
- Focal loss on balanced data: down-weighting easy examples can lead to underfitting.
- γ too high: most examples are effectively ignored.
- No α: the class balance can still be off.
- Accuracy as the metric: misleading on imbalanced data.
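The last anti-pattern is easy to demonstrate: on a data set with 1% positives, a model that always predicts the negative class scores 99% accuracy while finding no positives at all. A plain-Python sketch with made-up labels:

```python
def accuracy(y_true, y_pred):
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)

def recall(y_true, y_pred):
    """Share of actual positives that were predicted positive."""
    positives = [p for t, p in zip(y_true, y_pred) if t == 1]
    return sum(positives) / len(positives)

y_true = [1] * 10 + [0] * 990   # 1% positives
y_pred = [0] * len(y_true)      # trivial "always negative" classifier

print(accuracy(y_true, y_pred), recall(y_true, y_pred))
```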
🧪 Verification / duplicates
- Verified (Lin RetinaNet 2017, Cui 2019, Ben-Baruch ASL 2020).
- Confidence: A.
🕓 Changelog
| Date | Change |
|---|---|
| 2026-04-26 | LOSS-002 auto |
| 2026-05-08 | Phase 1 |
| 2026-05-10 | Manual cleanup: focal loss plus binary / multi-class / asymmetric / class-balanced code |