id: wiki-2026-0508-focal-loss
title: Focal Loss
category: 10_Wiki/Topics
status: verified
canonical_id: self
aliases: focal loss, class imbalance, RetinaNet, hard example mining, dense detection
duplicate_of: none
source_trust_level: A
confidence_score: 0.97
verification_status: applied
tags: deep-learning, loss-function, focal-loss, class-imbalance, object-detection, retinanet
raw_sources:
last_reinforced: 2026-05-10
github_commit: pending
tech_stack: language Python; framework PyTorch / TensorFlow

Focal Loss

In one line

"Cross-entropy scaled by a (1 - p_t)^γ modulating factor." Introduced by Lin et al. 2017 (RetinaNet). It down-weights easy examples and focuses training on hard ones, which is what made dense object detection trainable. Current uses: detection, imbalanced classification, and instance segmentation.

Core

Formula

FL(p_t) = -α_t (1 - p_t)^γ log(p_t)
  • p_t: predicted probability of the true class.
  • γ: focusing parameter (typically 2).
  • α_t: class-balancing weight (α for the positive class, 1 - α for the negative class).
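
For the binary case the definitions above can be written out explicitly. The block below is a restatement under the assumption that p is the predicted probability of the positive class and y ∈ {0, 1} is the label; the last line shows that γ = 0 with α_t = 1 gives back plain cross-entropy.

```latex
\[
p_t =
\begin{cases}
p     & \text{if } y = 1 \\
1 - p & \text{if } y = 0
\end{cases}
\qquad
\alpha_t =
\begin{cases}
\alpha     & \text{if } y = 1 \\
1 - \alpha & \text{if } y = 0
\end{cases}
\]
\[
\mathrm{FL}(p_t) = -\alpha_t \,(1 - p_t)^{\gamma} \,\log(p_t)
\]
\[
\mathrm{FL}(p_t)\Big|_{\gamma = 0,\ \alpha_t = 1} = -\log(p_t) = \mathrm{CE}(p_t)
\]
```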

Motivation

  • Dense detection: thousands of negatives versus tens of positives.
  • Under CE, the many easy negatives dominate the total loss.
  • The (1 - p_t)^γ factor shrinks the weight of well-classified examples.
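
The imbalance argument above can be checked with plain arithmetic. The sketch below assumes an illustrative batch of 10,000 easy negatives at p_t = 0.99 and 10 hard positives at p_t = 0.3 (made-up numbers, not taken from the paper) and measures what share of the total loss comes from the easy negatives under CE and under focal loss with γ = 2. The helper names are hypothetical.

```python
import math

def ce(p_t):
    """Cross-entropy for one example, given the probability of its true class."""
    return -math.log(p_t)

def focal(p_t, gamma=2.0):
    """Focal loss without the alpha term: CE scaled by (1 - p_t)^gamma."""
    return (1.0 - p_t) ** gamma * ce(p_t)

def easy_share(loss_fn, n_easy=10_000, p_easy=0.99, n_hard=10, p_hard=0.3):
    """Fraction of the summed loss contributed by the easy examples."""
    easy_total = n_easy * loss_fn(p_easy)
    hard_total = n_hard * loss_fn(p_hard)
    return easy_total / (easy_total + hard_total)

ce_share = easy_share(ce)
fl_share = easy_share(focal)
print(f"easy-negative share of loss: CE {ce_share:.1%}, focal {fl_share:.2%}")
```

Under these assumed numbers the easy negatives account for most of the CE total but only a small fraction of the focal total, so the gradient signal shifts to the hard positives.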

Applications

  1. Object detection (RetinaNet, FCOS, ATSS).
  2. Instance segmentation.
  3. Imbalanced classification.
  4. Medical (rare disease).
  5. Fraud / anomaly.

Alternatives

  • Class weights (CE).
  • Online Hard Example Mining (OHEM).
  • Balanced sampling.
  • Asymmetric loss (multi-label).
  • Class-balanced loss (Cui 2019).

💻 Patterns

Focal loss (binary)

import torch
import torch.nn.functional as F

def focal_loss_binary(logits, targets, alpha=0.25, gamma=2.0):
    """logits, targets: same shape; targets are float 0/1 labels."""
    p = torch.sigmoid(logits)
    p_t = p * targets + (1 - p) * (1 - targets)  # probability assigned to the true class
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')  # equals -log(p_t)
    focal = alpha_t * (1 - p_t).pow(gamma) * ce
    return focal.mean()

Focal loss (multi-class)

def focal_loss_multi(logits, targets, gamma=2.0, alpha=None):
    """logits: [B, C], targets: [B] int64 class ids, alpha: optional [C] tensor of per-class weights."""
    log_p = F.log_softmax(logits, dim=-1)
    log_pt = log_p.gather(1, targets.unsqueeze(1)).squeeze(1)
    pt = log_pt.exp()
    
    focal = -((1 - pt) ** gamma) * log_pt
    if alpha is not None:
        focal = alpha[targets] * focal
    return focal.mean()
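
A quick sanity check for the multi-class version: with γ = 0 and no α the modulating factor is 1 everywhere, so the result should match `F.cross_entropy`. The sketch below repeats a compact copy of the function so that it runs standalone; the batch shape and seed are arbitrary choices.

```python
import torch
import torch.nn.functional as F

def focal_loss_multi(logits, targets, gamma=2.0, alpha=None):
    """logits: [B, C], targets: [B] int64 class ids, alpha: optional [C] tensor."""
    log_p = F.log_softmax(logits, dim=-1)
    log_pt = log_p.gather(1, targets.unsqueeze(1)).squeeze(1)
    focal = -((1 - log_pt.exp()) ** gamma) * log_pt
    if alpha is not None:
        focal = alpha[targets] * focal
    return focal.mean()

torch.manual_seed(0)
logits = torch.randn(8, 5)
targets = torch.randint(0, 5, (8,))

ce = F.cross_entropy(logits, targets)
fl_gamma0 = focal_loss_multi(logits, targets, gamma=0.0)  # should equal CE
fl_gamma2 = focal_loss_multi(logits, targets, gamma=2.0)  # never larger than CE
print(ce.item(), fl_gamma0.item(), fl_gamma2.item())
```

Because (1 - p_t)^γ is at most 1, the γ = 2 value can only be less than or equal to the CE value on the same batch.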

As nn.Module

class FocalLoss(torch.nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
    
    def forward(self, logits, targets):
        return focal_loss_binary(logits, targets, self.alpha, self.gamma)

RetinaNet detection

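class RetinaNetLoss(torch.nn.Module):
    def __init__(self, n_classes, gamma=2.0, alpha=0.25):
        super().__init__()
        self.n_classes = n_classes  # kept for reference; shapes come from the inputs
        self.classification = FocalLoss(alpha=alpha, gamma=gamma)
        # Sum the box loss so it is normalized exactly once, by n_fg.
        self.regression = torch.nn.SmoothL1Loss(reduction='sum')

    def forward(self, cls_logits, box_preds, cls_targets, box_targets, fg_mask):
        # cls_logits, cls_targets: [A, C] (float one-hot targets); fg_mask: [A] bool.
        n_fg = fg_mask.sum().clamp(min=1)
        # FocalLoss returns a mean over all elements; recover the sum and
        # normalize by the number of foreground anchors, as RetinaNet does.
        cls_loss = self.classification(cls_logits, cls_targets) * cls_logits.numel() / n_fg
        box_loss = self.regression(box_preds[fg_mask], box_targets[fg_mask]) / n_fg
        return cls_loss + box_loss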

Class-balanced focal (Cui 2019)

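def class_balanced_focal(logits, targets, samples_per_class, beta=0.999, gamma=2.0):
    """Weight each class by the inverse of its effective number of samples."""
    # Accept lists or integer tensors; do the power in float on the logits' device.
    n = torch.as_tensor(samples_per_class, dtype=torch.float32, device=logits.device)
    eff_num = 1 - beta ** n
    weights = (1 - beta) / eff_num
    weights = weights / weights.sum() * len(n)  # normalize so the weights sum to C

    log_p = F.log_softmax(logits, -1)
    log_pt = log_p.gather(1, targets.unsqueeze(1)).squeeze(1)  # stays [B] even when B == 1
    pt = log_pt.exp()
    return (-weights[targets] * (1 - pt) ** gamma * log_pt).mean()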
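
To see what the effective-number weighting does in isolation, the sketch below computes the normalized class weights for a hypothetical three-class dataset with 5000, 500, and 50 samples. The counts and the helper name `cb_weights` are assumptions for illustration; the formula is the same (1 - β) / (1 - β^n) term used above.

```python
def cb_weights(samples_per_class, beta=0.999):
    """Class-balanced weights, normalized so they sum to the number of classes."""
    raw = [(1.0 - beta) / (1.0 - beta ** n) for n in samples_per_class]
    scale = len(raw) / sum(raw)
    return [w * scale for w in raw]

counts = [5000, 500, 50]  # illustrative head / mid / tail class sizes
weights = cb_weights(counts)
uniform = cb_weights(counts, beta=0.0)  # beta = 0 disables re-weighting

for n, w in zip(counts, weights):
    print(f"n={n:5d}  weight={w:.3f}")
```

Rarer classes receive larger weights, and β controls how aggressive the re-weighting is: β = 0 gives uniform weights, while β close to 1 approaches inverse-frequency weighting.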

Asymmetric focal (multi-label)

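def asymmetric_focal(logits, targets, gamma_pos=1, gamma_neg=4, clip=0.05):
    """Asymmetric loss (Ben-Baruch 2020): separate γ for positives and negatives."""
    p = torch.sigmoid(logits)
    # Shifted negative probability p_m = max(p - clip, 0): very easy negatives get zero loss.
    p_neg = (p - clip).clamp(min=0)

    pos_loss = targets * (1 - p) ** gamma_pos * torch.log(p.clamp(min=1e-8))
    # The log term also uses the shifted probability: L- = p_m^γ- * log(1 - p_m).
    neg_loss = (1 - targets) * p_neg ** gamma_neg * torch.log((1 - p_neg).clamp(min=1e-8))
    return -(pos_loss + neg_loss).mean()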
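
The effect of the probability shift is easiest to see on a single negative example. The scalar sketch below follows the shifted-probability form of the negative term, p_m = max(p - clip, 0), with the same default values as the function above; the helper name `asl_negative_term` is hypothetical.

```python
import math

def asl_negative_term(p, gamma_neg=4, clip=0.05):
    """Loss of one negative example whose predicted positive probability is p."""
    p_m = max(p - clip, 0.0)  # shifted probability; zero for very easy negatives
    return -(p_m ** gamma_neg) * math.log(max(1.0 - p_m, 1e-8))

easy = asl_negative_term(0.03)    # below the clip margin
medium = asl_negative_term(0.50)
hard = asl_negative_term(0.95)    # confidently wrong negative
print(easy, medium, hard)
```

Negatives with p below the margin contribute exactly zero loss, which is a hard threshold on top of the soft down-weighting from γ.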

γ tuning

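# Typical γ starting points. Only γ=2 comes from RetinaNet; the rest are
# rules of thumb and should be tuned on validation data.
GAMMA_VALUES = {
    'extreme_imbalance': 5.0,
    'object_detection': 2.0,  # RetinaNet
    'medical_rare_disease': 3.0,
    'mild_imbalance': 1.0,
    'no_imbalance': 0,  # identical to CE
}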
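
To make the choice of γ concrete, the sketch below tabulates the modulating factor (1 - p_t)^γ for a hard, an easy, and a very easy example. The p_t values and the γ grid are illustrative assumptions.

```python
def modulating_factor(p_t, gamma):
    """Focal-loss scaling (1 - p_t)^gamma applied on top of cross-entropy."""
    return (1.0 - p_t) ** gamma

P_T_VALUES = (0.5, 0.9, 0.99)  # hard, easy, very easy (illustrative)
GAMMAS = (0, 1, 2, 5)

table = {g: [modulating_factor(p, g) for p in P_T_VALUES] for g in GAMMAS}
for g, row in table.items():
    cells = ", ".join(f"p_t={p}: {w:.2e}" for p, w in zip(P_T_VALUES, row))
    print(f"gamma={g}: {cells}")
```

At γ = 2 an example with p_t = 0.9 is scaled down by a factor of 100 relative to CE. At γ = 5 even the p_t = 0.5 example keeps only about 3% of its CE loss, which is the mechanism behind the "γ too high" failure mode.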

Focal vs CE comparison

def compare_losses(p_easy=0.95, p_hard=0.5):
    """Compare the loss contribution of an easy and a hard example."""
    ce_easy = -torch.log(torch.tensor(p_easy))
    ce_hard = -torch.log(torch.tensor(p_hard))

    fl_easy = (1 - p_easy) ** 2 * ce_easy
    fl_hard = (1 - p_hard) ** 2 * ce_hard

    print(f'CE: easy {ce_easy:.3f}, hard {ce_hard:.3f}, ratio {ce_hard/ce_easy:.1f}x')
    print(f'Focal γ=2: easy {fl_easy:.4f}, hard {fl_hard:.3f}, ratio {fl_hard/fl_easy:.0f}x')
    # Under focal loss the hard example dominates by a far larger ratio than under CE.

OHEM alternative

def ohem_loss(logits, targets, ratio=0.25):
    """Keep only the hardest `ratio` fraction of examples in the batch."""
    losses = F.cross_entropy(logits, targets, reduction='none')
    n_keep = max(1, int(len(losses) * ratio))  # keep at least one example to avoid a NaN mean
    topk = losses.topk(n_keep)[0]
    return topk.mean()

Combine: focal + dice (segmentation)

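def focal_dice(logits, targets, smooth=1.0):
    """Focal loss plus soft Dice loss for binary segmentation."""
    fl = focal_loss_binary(logits, targets)
    p = torch.sigmoid(logits)
    intersection = (p * targets).sum()
    # Smooth numerator and denominator alike so an empty mask with an empty prediction gives Dice loss 0.
    dice = 1 - (2 * intersection + smooth) / (p.sum() + targets.sum() + smooth)
    return fl + dice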

Eval metric (PR-AUC for imbalanced)

from sklearn.metrics import average_precision_score
def eval_imbalanced(probs, targets):
    return {
        'pr_auc': average_precision_score(targets, probs),
        # Not accuracy; it is meaningless on imbalanced data.
    }
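
Why accuracy misleads here can be shown with a degenerate classifier. The sketch below assumes a synthetic dataset with 1% positives and a model that outputs probability 0 for everything; the data is made up for illustration.

```python
import numpy as np
from sklearn.metrics import accuracy_score, average_precision_score

# Synthetic labels: 10 positives, 990 negatives (1% prevalence).
targets = np.array([1] * 10 + [0] * 990)

# A useless model that scores every example as negative.
probs = np.zeros(len(targets))
preds = (probs >= 0.5).astype(int)

acc = accuracy_score(targets, preds)
ap = average_precision_score(targets, probs)
print(f"accuracy={acc:.3f}, PR-AUC (average precision)={ap:.3f}")
```

The model finds no positives at all, yet its accuracy looks excellent, while average precision falls to roughly the positive prevalence.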

Decision criteria

| Situation | Loss |
| --- | --- |
| Dense object detection | Focal (γ=2) |
| Mild imbalance | Class-weighted CE |
| Extreme imbalance | Focal (γ=3-5) |
| Multi-label imbalanced | Asymmetric focal |
| Segmentation | Focal + Dice |
| Easy task | Plain CE |

Defaults: for imbalance above roughly 100:1, start with focal loss at γ=2 and α=0.25. For multi-label problems, use asymmetric focal loss. Evaluate with PR-AUC, not accuracy.

🔗 Graph

🤖 LLM usage

When to use: detection, imbalanced data, medical tasks. When not to use: balanced data, regression.

Anti-patterns

  • Focal loss on balanced data: risks underfitting.
  • γ too high: most examples are effectively ignored.
  • No α: the class balance is still off.
  • Accuracy as the metric: misleading on imbalanced data.

🧪 Verification / Duplicates

  • Verified (Lin RetinaNet 2017, Cui 2019, Ben-Baruch ASL 2020).
  • Trust level A.

🕓 Changelog

| Date | Change |
| --- | --- |
| 2026-04-26 | LOSS-002 auto |
| 2026-05-08 | Phase 1 |
| 2026-05-10 | Manual cleanup: focal loss with binary / multi-class / asymmetric / class-balanced code |