Notice
Recent Posts
Recent Comments
Link
일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | |||
5 | 6 | 7 | 8 | 9 | 10 | 11 |
12 | 13 | 14 | 15 | 16 | 17 | 18 |
19 | 20 | 21 | 22 | 23 | 24 | 25 |
26 | 27 | 28 | 29 | 30 | 31 |
Tags
- SSAC
- loss function
- segmentation
- Dice Loss
- sm-segmentation
- focal loss
- Satellite Image
- Satel
- LightGBM
- 시계열
Archives
- Today
- Total
고양이는 털털해
Semantic Segmentation : Focal Loss 본문
Focal Loss¶
1. 정리¶
- focal loss는 데이터가 불균형적인 특징이 존재하는 semantic segmentation에서 많이 사용하는 loss 함수 이다.
- 모델이 정답을 확신하며 추측한 경우 그에 대한 비중은 줄이고 확신하지 못하는 경우의 비중을 키워서 잘 학습한 내용보다 잘 학습하지 못한 내용에 집중하도록 만들어진 loss 함수이다.
- 달리 말하면 불균형 데이터셋에서 빈도가 적은 데이터에 비중을 조절하면서 모델이 확신을 갖지 못하는 항목에 대해 좀 더 모험적인 판단을 하도록 하여, 데이터 셋에서 많이 보지 못한 데이터에 대해 좀 더 잘 학습하도록 짜여진 loss 함수이다.
- focal loss의 수식은 아래와 같다.
$$Focal\ Loss\ = -\alpha_t(1-p_t)^\gamma log(p_t)$$
- sm-segmentation 오픈라이브러리에서 focal loss 는 binary focal loss와 categorical focal loss로 짜여져 있다.
2. Focal Loss 함수¶
$$Focal\ Loss\ = -\alpha_t(1-p_t)^\gamma log(p_t)$$- $\alpha$앞의 (-)항은 cross entropy loss에서와 같다. log 함수를 loss 함수로 사용하고 $p_t$값이 0~1 사이의 값이기에 log 계산시 음의 값이 나오는 것을 부호를 바꿔 정답에 가까운 예측을 했을 때 크기가 점점 줄어들어 0이 되는 loss 함수로 쓰기 위함이다.
- 그림으로 그려보면 아래와 같다. 만약 (-)가 없다면 1을 향해 점점 커지는 함수가 될 것이다.
In [1]:
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0,1,500)
y = -np.log(x)
plt.figure(figsize=(10,5))
plt.plot(x,y)
plt.show()
- $\alpha_t$는 focal loss의 고유한 아이디어는 아닌 것으로 보인다. 불균형 데이터의 계산 상 비중을 조정하기 위한 가중치를 뜻하는 표현이다. 흔하지 않은 데이터에 큰 비중을 두어 계산한 loss 값에서의 비중을 조정해 주기 위한 표현으로 볼 수 있다.
- $p_t$는 모델이 정답을 확신하는 정도롤 뜻한다고 이해할 수 있을 것이다. binary 모델 특성 상 정답이 1인 데이터에 대한 모델의 output 값이 0.7이라면 모델은 0.7 만큼 정답을 확신한다고 이해할 수 있다. 반면 정답이 0인 데이터에 대한 모델의 output 값이 0.7이면 정답이 0일 것에 대한 모델의 확신은 1-0.7인 0.3 이라고 볼 수 있을 것이다. 모델은 해당 데이터가 0.7의 확률로 1일 것이라고 예측했다고 볼 수 있으니 실제 정답인 0은 0.3의 확률로 정답일 것이라고 예측했다고 볼 수 있다. binary 문제이기에 이런 계산을 해볼 수 있는 것 같다.
- $\gamma$는 focal loss에서 모델이 잘 예측한 것 보다는 잘 예측하지 못한 것에 집중하도록 만들어주는 표현이라고도 볼 수 있을 것 같다. $\gamma$가 0이라면 $(1-p_t)^\gamma$항이 1이 되어 binary cross entropy 와 동일한 loss함수가 되겠지만 이 값이 0보다 크다면 cross entropy 함수 보다 좀 더 빠르게 감소하는 함수가 된다.
- $\gamma$값에 대한 focal loss 함수의 개형은 아래 그림으로 확인할 수 있다.
In [2]:
from IPython.display import Image
Image("FL_v_CE.png")
Out[2]:
이미지 출처 : Focal Loss for Dense Object Detection : https://arxiv.org/abs/1708.02002
- cross entropy loss는 모델 예측 값이 0.6일 때의 loss가 0.5 근처로 loss를 더 줄이기 위해 정답이 1일 것을 좀 더 높은 확률로 예측할 유인이 남아 있다. 어떻게 보면 cross entropy loss는 모델이 정답일 것을 매우 확신하면서 예측하도록 유도하는 측면이 있다.
- 반면 $\gamma$가 2인 focal loss 에서는 모델 예측 값이 0.6일 때의 loss는 0에 충분히 가까워 모델이 해당 예측에 대해 좀 더 확신을 갖도록 하게 하는 유인이 적으며 모델이 분류하기 어려워 확신을 갖지 못한 0.5 이하의 구간에서의 loss함수 기울기를 급격하게 만들어 이러한 문제들에 좀 더 집중하도록 격려하는 것으로 이해할 수 있을 것이다.
3. sm-segmentation에서의 계산식 확인: binary, categorical¶
def categorical_focal_loss(gt, pr, gamma=2.0, alpha=0.25, class_indexes=None, **kwargs):
r"""Implementation of Focal Loss from the paper in multiclass classification
Formula:
loss = - gt * alpha * ((1 - pr)^gamma) * log(pr)
Args:
gt: ground truth 4D keras tensor (B, H, W, C) or (B, C, H, W)
pr: prediction 4D keras tensor (B, H, W, C) or (B, C, H, W)
alpha: the same as weighting factor in balanced cross entropy, default 0.25
gamma: focusing parameter for modulating factor (1-p), default 2.0
class_indexes: Optional integer or list of integers, classes to consider, if ``None`` all classes are used.
"""
backend = kwargs['backend']
gt, pr = gather_channels(gt, pr, indexes=class_indexes, **kwargs)
# clip to prevent NaN's and Inf's
pr = backend.clip(pr, backend.epsilon(), 1.0 - backend.epsilon())
# Calculate focal loss
loss = - gt * (alpha * backend.pow((1 - pr), gamma) * backend.log(pr))
return backend.mean(loss)
def binary_focal_loss(gt, pr, gamma=2.0, alpha=0.25, **kwargs):
r"""Implementation of Focal Loss from the paper in binary classification
Formula:
loss = - gt * alpha * ((1 - pr)^gamma) * log(pr) \
- (1 - gt) * alpha * (pr^gamma) * log(1 - pr)
Args:
gt: ground truth 4D keras tensor (B, H, W, C) or (B, C, H, W)
pr: prediction 4D keras tensor (B, H, W, C) or (B, C, H, W)
alpha: the same as weighting factor in balanced cross entropy, default 0.25
gamma: focusing parameter for modulating factor (1-p), default 2.0
"""
backend = kwargs['backend']
# clip to prevent NaN's and Inf's
pr = backend.clip(pr, backend.epsilon(), 1.0 - backend.epsilon())
loss_1 = - gt * (alpha * backend.pow((1 - pr), gamma) * backend.log(pr))
loss_0 = - (1 - gt) * ((1 - alpha) * backend.pow((pr), gamma) * backend.log(1 - pr))
loss = backend.mean(loss_0 + loss_1)
return loss
- binary focal loss의 formula를 먼저 보면 gt와 (1-gt)항을 곱한 것이 추가 되어 있는 것을 확인할 수 있다. ground truth가 1일 경우 (1-gt)가 0이 되어 아래 식이 사라지며 ground truth가 0일 경우 위의 식이 사라진다. 그에 맞게 log 안을 예측값으로 정답일 것을 확신하는 정도로 계산되도록 맞춰주어 위에서 살펴본 focal loss를 구현하고 있음을 확인할 수 있다.
- categorical focal loss의 formula를 살펴보면 binary focal loss보다 짧은 식이 쓰여져 있는데 단순히 ground truth만을 곱하고 모델의 예측값을 그대로 log에 넣는 것을 확인할 수 있다. 이러한 차이는 gt의 텐서 차원과 binary와 multi class의 task의 차이를 생각해 보면 이해할 수 있다.
- binary task의 경우 channel은 1로 모든 픽셀에 대해 정답일 것을 확신하는 정도는 정답이 0인 경우와 1인 경우를 예측값을 조정해서 살펴봐야 한다.
- multiclass의 경우 gt가 (Batch, Height, Width, Channel)의 4차원 텐서인데 분류하고자 하는 class 만큼의 channel을 갖게 될 것이다. 한 픽셀의 모든 채널을 한번에 살펴보게 되므로 gt를 곱하고 예측값을 그대로 계산하면 gt가 0인 채널은 정답이 아니고 해당 채널의 gt 0을 곱한 계산 값은 0이 되어 무시할 수 있게되고, gt가 1인 채널의 예측값은 그대로 정답일 것을 확신하는 정도로 이해할 수 있게 된다.
'공부 > SSACxAIFFEL' 카테고리의 다른 글
lightgbm 하이퍼 파라미터 (0) | 2021.06.21 |
---|---|
Semantic Segmentation : DICE LOSS (0) | 2021.06.21 |
SSACxAIFFEL 10주차: 21'03.08. ~ 03.12; 전이학습, keras (2) | 2021.03.12 |
SSACxAIFFEL 5주차 : 21' 02.01 ~ 02.05 (0) | 2021.02.05 |
SSACxAIFFEL 4주차 : 21' 01.25 ~ 01.29 (0) | 2021.01.31 |