Notice
Recent Posts
Recent Comments
Link
«   2025/01   »
1 2 3 4
5 6 7 8 9 10 11
12 13 14 15 16 17 18
19 20 21 22 23 24 25
26 27 28 29 30 31
Archives
Today
Total
관리 메뉴

고양이는 털털해

Semantic Segmentation : Focal Loss 본문

공부/SSACxAIFFEL

Semantic Segmentation : Focal Loss

Attagungho 2021. 6. 23. 00:28

focal_loss_blog

Focal Loss

1. 정리

  • focal loss는 데이터가 불균형적인 특징이 존재하는 semantic segmentation에서 많이 사용하는 loss 함수 이다.
    • 모델이 정답을 확신하며 추측한 경우 그에 대한 비중은 줄이고 확신하지 못하는 경우의 비중을 키워서 잘 학습한 내용보다 잘 학습하지 못한 내용에 집중하도록 만들어진 loss 함수이다.
    • 달리 말하면 불균형 데이터셋에서 빈도가 적은 데이터에 비중을 조절하면서 모델이 확신을 갖지 못하는 항목에 대해 좀 더 모험적인 판단을 하도록 하여, 데이터 셋에서 많이 보지 못한 데이터에 대해 좀 더 잘 학습하도록 짜여진 loss 함수이다.
    • focal loss의 수식은 아래와 같다.
      $$Focal\ Loss\ = -\alpha_t(1-p_t)^\gamma log(p_t)$$
  • sm-segmentation 오픈라이브러리에서 focal loss 는 binary focal loss와 categorical focal loss로 짜여져 있다.

목차로


2. Focal Loss 함수

$$Focal\ Loss\ = -\alpha_t(1-p_t)^\gamma log(p_t)$$

  • $\alpha$앞의 (-)항은 cross entropy loss에서와 같다. log 함수를 loss 함수로 사용하고 $p_t$값이 0~1 사이의 값이기에 log 계산시 음의 값이 나오는 것을 부호를 바꿔 정답에 가까운 예측을 했을 때 크기가 점점 줄어들어 0이 되는 loss 함수로 쓰기 위함이다.
    • 그림으로 그려보면 아래와 같다. 만약 (-)가 없다면 1을 향해 점점 커지는 함수가 될 것이다.
In [1]:
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0,1,500)
y = -np.log(x)
plt.figure(figsize=(10,5))
plt.plot(x,y)
plt.show()
<ipython-input-1-ea8e754892e8>:4: RuntimeWarning: divide by zero encountered in log
  y = -np.log(x)
  • $\alpha_t$는 focal loss의 고유한 아이디어는 아닌 것으로 보인다. 불균형 데이터의 계산 상 비중을 조정하기 위한 가중치를 뜻하는 표현이다. 흔하지 않은 데이터에 큰 비중을 두어 계산한 loss 값에서의 비중을 조정해 주기 위한 표현으로 볼 수 있다.
  • $p_t$는 모델이 정답을 확신하는 정도롤 뜻한다고 이해할 수 있을 것이다. binary 모델 특성 상 정답이 1인 데이터에 대한 모델의 output 값이 0.7이라면 모델은 0.7 만큼 정답을 확신한다고 이해할 수 있다. 반면 정답이 0인 데이터에 대한 모델의 output 값이 0.7이면 정답이 0일 것에 대한 모델의 확신은 1-0.7인 0.3 이라고 볼 수 있을 것이다. 모델은 해당 데이터가 0.7의 확률로 1일 것이라고 예측했다고 볼 수 있으니 실제 정답인 0은 0.3의 확률로 정답일 것이라고 예측했다고 볼 수 있다. binary 문제이기에 이런 계산을 해볼 수 있는 것 같다.
  • $\gamma$는 focal loss에서 모델이 잘 예측한 것 보다는 잘 예측하지 못한 것에 집중하도록 만들어주는 표현이라고도 볼 수 있을 것 같다. $\gamma$가 0이라면 $(1-p_t)^\gamma$항이 1이 되어 binary cross entropy 와 동일한 loss함수가 되겠지만 이 값이 0보다 크다면 cross entropy 함수 보다 좀 더 빠르게 감소하는 함수가 된다.
    • $\gamma$값에 대한 focal loss 함수의 개형은 아래 그림으로 확인할 수 있다.
In [2]:
from IPython.display import Image
Image("FL_v_CE.png")
Out[2]:

이미지 출처 : Focal Loss for Dense Object Detection : https://arxiv.org/abs/1708.02002

  • cross entropy loss는 모델 예측 값이 0.6일 때의 loss가 0.5 근처로 loss를 더 줄이기 위해 정답이 1일 것을 좀 더 높은 확률로 예측할 유인이 남아 있다. 어떻게 보면 cross entropy loss는 모델이 정답일 것을 매우 확신하면서 예측하도록 유도하는 측면이 있다.
  • 반면 $\gamma$가 2인 focal loss 에서는 모델 예측 값이 0.6일 때의 loss는 0에 충분히 가까워 모델이 해당 예측에 대해 좀 더 확신을 갖도록 하게 하는 유인이 적으며 모델이 분류하기 어려워 확신을 갖지 못한 0.5 이하의 구간에서의 loss함수 기울기를 급격하게 만들어 이러한 문제들에 좀 더 집중하도록 격려하는 것으로 이해할 수 있을 것이다.

목차로


3. sm-segmentation에서의 계산식 확인: binary, categorical

def categorical_focal_loss(gt, pr, gamma=2.0, alpha=0.25, class_indexes=None, **kwargs):
    r"""Implementation of Focal Loss from the paper in multiclass classification
    Formula:
        loss = - gt * alpha * ((1 - pr)^gamma) * log(pr)
    Args:
        gt: ground truth 4D keras tensor (B, H, W, C) or (B, C, H, W)
        pr: prediction 4D keras tensor (B, H, W, C) or (B, C, H, W)
        alpha: the same as weighting factor in balanced cross entropy, default 0.25
        gamma: focusing parameter for modulating factor (1-p), default 2.0
        class_indexes: Optional integer or list of integers, classes to consider, if ``None`` all classes are used.
    """

    backend = kwargs['backend']
    gt, pr = gather_channels(gt, pr, indexes=class_indexes, **kwargs)

    # clip to prevent NaN's and Inf's
    pr = backend.clip(pr, backend.epsilon(), 1.0 - backend.epsilon())

    # Calculate focal loss
    loss = - gt * (alpha * backend.pow((1 - pr), gamma) * backend.log(pr))

    return backend.mean(loss)


def binary_focal_loss(gt, pr, gamma=2.0, alpha=0.25, **kwargs):
    r"""Implementation of Focal Loss from the paper in binary classification
    Formula:
        loss = - gt * alpha * ((1 - pr)^gamma) * log(pr) \
               - (1 - gt) * alpha * (pr^gamma) * log(1 - pr)
    Args:
        gt: ground truth 4D keras tensor (B, H, W, C) or (B, C, H, W)
        pr: prediction 4D keras tensor (B, H, W, C) or (B, C, H, W)
        alpha: the same as weighting factor in balanced cross entropy, default 0.25
        gamma: focusing parameter for modulating factor (1-p), default 2.0
    """
    backend = kwargs['backend']

    # clip to prevent NaN's and Inf's
    pr = backend.clip(pr, backend.epsilon(), 1.0 - backend.epsilon())

    loss_1 = - gt * (alpha * backend.pow((1 - pr), gamma) * backend.log(pr))
    loss_0 = - (1 - gt) * ((1 - alpha) * backend.pow((pr), gamma) * backend.log(1 - pr))
    loss = backend.mean(loss_0 + loss_1)
    return loss
  • binary focal loss의 formula를 먼저 보면 gt와 (1-gt)항을 곱한 것이 추가 되어 있는 것을 확인할 수 있다. ground truth가 1일 경우 (1-gt)가 0이 되어 아래 식이 사라지며 ground truth가 0일 경우 위의 식이 사라진다. 그에 맞게 log 안을 예측값으로 정답일 것을 확신하는 정도로 계산되도록 맞춰주어 위에서 살펴본 focal loss를 구현하고 있음을 확인할 수 있다.
  • categorical focal loss의 formula를 살펴보면 binary focal loss보다 짧은 식이 쓰여져 있는데 단순히 ground truth만을 곱하고 모델의 예측값을 그대로 log에 넣는 것을 확인할 수 있다. 이러한 차이는 gt의 텐서 차원과 binary와 multi class의 task의 차이를 생각해 보면 이해할 수 있다.
    • binary task의 경우 channel은 1로 모든 픽셀에 대해 정답일 것을 확신하는 정도는 정답이 0인 경우와 1인 경우를 예측값을 조정해서 살펴봐야 한다.
    • multiclass의 경우 gt가 (Batch, Height, Width, Channel)의 4차원 텐서인데 분류하고자 하는 class 만큼의 channel을 갖게 될 것이다. 한 픽셀의 모든 채널을 한번에 살펴보게 되므로 gt를 곱하고 예측값을 그대로 계산하면 gt가 0인 채널은 정답이 아니고 해당 채널의 gt 0을 곱한 계산 값은 0이 되어 무시할 수 있게되고, gt가 1인 채널의 예측값은 그대로 정답일 것을 확신하는 정도로 이해할 수 있게 된다.

목차로