이번 글을 딥러닝 분류 모델을 작성하던중 라벨 불균형 문제를 해결하기 위해 알게 된 Focal Loss에 대해 정리한 글입니다.

TL;DR

  • Focal Loss 는 2018년 Focal Loss for Dense Object Detection에서 제안된 손실함수로 분류 테스크에서 데이터 불균형을 해결하기 위해 제안되었습니다.
  • 데이터 불균형과 관련해서는 아래의 연구에서 데이터 비율이 10배를 넘어가면 분류 성능이 저하된다는 것을 실험적으로 증명하였습니다.

The Effect of Class Distribution on Classifier Learning:An Empirical Study

  • 이진 분류 문제에서 일반적으로 사용하는 Cross Entropy Loss는 잘못 예측된 데이터에 대해 큰 패널티를 주는게 목적인 함수입니다.
  • 일반적인 Cross Entropy는 데이터 불균형을 해결할 수 없었고, 이를 해결하기 위해 데이터의 개수에 따라 가중치를 주는 Balanced Cross Entropy Loss이 제안되었습니다.
  • 하지만 단순히 개수만으로 데이터 분류의 쉽고 어려움을 판단하기에는 무리가 있었고, 이를 해결하기 위해 제안된 것이 Focal Loss 입니다.
  • Focal Loss는 모델이 쉽게 분류할 수 있는 다수(majority, easy negative)에 대해 loss를 더 많이 낮춤으로서 다수 데이터의 영향이 누적되는 현상을 막고자 했습니다.

Focal Loss

  • 구체적인 수식에 대한 설명은 자세하게 설명된 글이 많기 때문에 생략하겠습니다.
  • 아래는 Cross Entropy Loss 와 Focal Loss 의 식입니다.

$$
CE = -log(p) \
FL = -(1-p)^{\gamma}log(p)
$$

  • 위 식을 그래프로 그린다면 아래와 같습니다.

  • 즉, $\gamma$ 파라미터를 조정하여 이미 학습이 잘되는(well-classified) 다수의 데이터에 대해서는 loss를 더 빠르게 낮춤으로서 누적되는 영향을 줄이는 것입니다.

Pytorch Implement

  • Pytorch 에서는 torchvision.ops.sigmoid_focal_loss 함수로 Focal Loss를 사용할 수 있습니다.

  • 소스 코드를 확인해보니 함수안에 sigmoid 함수가 포함되어 있기 때문에 입력값은 확률값이 아니라 모델 출력을 그대로 넣어주어야 합니다.

+ Recent posts