Notice
Recent Posts
Recent Comments
Link
«   2024/12   »
1 2 3 4 5 6 7
8 9 10 11 12 13 14
15 16 17 18 19 20 21
22 23 24 25 26 27 28
29 30 31
Archives
Today
Total
관리 메뉴

고양이는 털털해

Semantic Segmentation : DICE LOSS 본문

공부/SSACxAIFFEL

Semantic Segmentation : DICE LOSS

Attagungho 2021. 6. 21. 00:15

dice_loss_blog

dice loss


목차

1. 정리

2. dice loss 함수

3. sm-segmentation에서의 계산식 확인: f1-score와 dice coefficient

4. multi class dice loss와 class weight

5. 참조 문헌


1. 정리

  • dice loss는 데이터가 불균형적인 특징이 존재하는 semantic segmentation에서 많이 사용하는 loss 함수 이다.
  • dice loss의 수식은 아래와 같다. $$Dice\ Loss\ = 1 - \frac{2 \sum {p \cdot \hat{p}}} {\sum{p^2} +\sum{\hat{p}^2}} $$
  • sm-segmentation 오픈라이브러리에서 dice loss 는 $f_1-score$ 로 계산한다.
  • dice loss 는 $1 - f_1\ score$와 같고 $f_1\ score$의 수식은 아래와 같다. $$f_1\ score\ = \frac{2}{\frac{1}{recall} + \frac{1}{precision}}\ = \frac{2\cdot precision \cdot recall}{precision+recall}$$
  • multi-class segmentation에서 class 별 loss 값의 비중을 맞춰주기 위해서는 주 관심사가 아니면서 큰 비중을 차지하는 class 예를 들어 background가 대부분인 데이터를 제거하여 데이터의 비중을 어느정도 맞춰주도록 전처리를 하거나 전체 데이터 셋에서 각 class가 차지하는 비중을 계산하여 그의 역수를 class weight로 주어서 loss함수를 계산하도록 해야 한다.

목차로


2. dice loss 함수

  • 수식 $$Dice\ Loss\ = 1 - \frac{2 \sum {p \cdot \hat{p}}} {\sum{p^2} +\sum{\hat{p}^2}} $$
  • p는 네트웍의 출력, $\hat{p}$ 는 정답 레이블이다.
    • segmentation에서 정답 레이블은 binary 값으로 주어지며 따라서 정답이 아닌 영역은 $\hat{p}$가 0으로 loss함수의 분자 계산에서 제외된다.
    • 정답이 0인 것을 0에 가깝게 예측하는 것은 분모의 $p^2$의 합을 줄여 분모를 감소시킬 수는 있으나 분자를 증가시키지는 않는다.
    • 정답이 0인 것을 1에 가깝게 예측하는 것은 분자는 증가시키지 않으면서 분모만을 증가시키므로 loss 함수를 비교적 크게 증가시킨다.
    • 정답인 영역을 제대로 맞추는 것을 크게 평가하는 loss함수 이므로 정답이 아닌 영역에 비해 정답인 영역이 적은 semantic segmentation에 자주 사용된다고 한다.

목차로


3. sm-segmentation에서의 계산식 확인: f1-score와 dice loss

  • sm-sgementation open library에서 diceloss 계산 코드를 보면 아래와 같다.
    {python}
    class DiceLoss(Loss):
      def __call__(self, gt, pr):
          return 1 - F.f_score(
              gt,
              pr,
              beta=self.beta,
              class_weights=self.class_weights,
              class_indexes=self.class_indexes,
              smooth=self.smooth,
              per_image=self.per_image,
              threshold=None,
              **self.submodules
          )
  • 위의 코드를 보면 1- f_score를 결과 값으로 반환하고 있다.

    • dice loss는 1-dice coefficient인데 dice coefficient는 f1-score와 동일하다고 한다.
  • dice coefficient의 수식은 아래와 같다. $$ DSC = \frac{2|A|\cap|B|}{|A|+|B|}$$

  • 분모와 분자의 계산을 그림으로 보면 아래와 같으며 dice loss의 분수의 수식과 dice coefficient의 수식이 동일한 것을 알 수 있다.
In [3]:
from IPython.display import Image
Image("intersection-1.png")
Out[3]:

분자 계산; 이미지 출처 : https://www.jeremyjordan.me/semantic-segmentation/

In [4]:
Image("denominator-1.png")
Out[4]:

분모 계산; 이미지 출처 : https://www.jeremyjordan.me/semantic-segmentation/

  • f1-score의 수식은 아래와 같다.
    $$f_1\ score\ = \frac{2}{\frac{1}{recall} + \frac{1}{precision}}\ = \frac{2\cdot precision \cdot recall}{precision+recall}$$
  • A를 예측, B를 타겟이라고 했을 때 precision 과 recall은 아래와 같이 나타낼 수 있다.
    $$precision\ = \frac{TP}{TP+FN}\ = \frac{|AB|}{|A|},\quad recall\ = \frac{TP}{TP+FN}\ = \frac{|AB|}{|B|}$$

    • precison은 정의 상 전체 예측한 |A| 중에 맞춘 것 |AB|, recall은 정의 상 정답|B| 중에 맞춘 것 |AB| 이므로 위와 같이 표현할 수 있다.
    • 이를 대입하여 f1-score의 수식을 풀어내면 아래와 같으며 DSC와 동일한 것을 확인할 수 있다. $$ f1-score\ = \frac{2|AB|}{|A|+|B|} DSC\ = \frac{2|A|\cap|B|}{|A|+|B|}$$
  • sm-segmentation f1-score의 계산 코드를 보면 아래와 같다.

    {python}
    def f_score(gt, pr, beta=1, class_weights=1, class_indexes=None, smooth=SMOOTH, per_image=False, threshold=None,
              **kwargs):
    
      backend = kwargs['backend']
    
      gt, pr = gather_channels(gt, pr, indexes=class_indexes, **kwargs)
      pr = round_if_needed(pr, threshold, **kwargs)
      axes = get_reduce_axes(per_image, **kwargs)
    
      # calculate score
      tp = backend.sum(gt * pr, axis=axes)
      fp = backend.sum(pr, axis=axes) - tp
      fn = backend.sum(gt, axis=axes) - tp
    
      score = ((1 + beta ** 2) * tp + smooth) \
              / ((1 + beta ** 2) * tp + beta ** 2 * fn + fp + smooth)
      score = average(score, per_image, class_weights, **kwargs)
    
      return score
  • $\beta$의 기본값인 1로 계산하면 아래와 같으며 f1-score의 수식과 동일한 것을 확인할 수 있다.
    $$ \frac{(1+\beta^2)\cdot TP}{(1+\beta^2)\cdot TP + \beta^2 \cdot FN + FP} = \frac{2TP}{2TP+FN+FP} $$

목차로


4. multi class dice loss와 class weights

  • 3번의 최종 score 계산을 보면 f-1score의 average를 계산하는 것을 볼 수 있으며 인자로 score, class_weights가 있는 것을 확인할 수 있다.
  • average의 코드를 확인하면 아래와 같다.
{python
def average(x, per_image=False, class_weights=None, **kwargs):
    backend = kwargs['backend']
    if per_image:
        x = backend.mean(x, axis=0)
    if class_weights is not None:
        x = x * class_weights
    return backend.mean(x)
  • score 값에 각 class_weights를 곱해서 전체 평균을 계산하여 반환하는 것을 볼 수 있다.
    • multi class segmentation의 경우 채널 별 f1-score를 계산한 후에 정해진 가중치를 곱해서 가중 평균을 계산하여 최종 loss를 도출하는 것으로 생각할 수 있다.
    • 따라서 class 별 가중치를 주지 않을 경우 채널별 f1-score의 단순 평균을 계산하게 되며 per image가 True 인 경우 사진 별로 per image가 False인 경우에는 batch 별로 많은 픽셀을 차지하는 class의 loss가 큰 비중을 차지하게 됨을 예상할 수 있다.
    • 모든 class가 dice loss에 가지는 비중을 동일하게 맞추려면 훈련 데이터 셋에서 해당 클래스가 차지하는 비중을 계산 후 그 역수를 class weight로 주는 것이 적절하다.
    • 특별히 loss에 비중을 가지게 하고 싶은 class가 있다면 그 class의 가중치를 높게 조절해 보는 것이 좋겠다.

목차로